diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3823f8b04443ae0f6afdd445bee262509d20c3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0727ad72c5c6400c1cf06a9c0a5a8abaf6ee682 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6f1390dee001d0bff23b74c839fc2a781de4bf1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fe6bf8aed3c9b1544ee9dfed194a6d8dca67f3a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62afa7da64379a34913f566be602548a0a3c6d0b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edb0ca394162c80c24df275555fcab2fa7572304 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f286ecd930b36a5e79c7656728a4661e9981f5ad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc differ diff --git 
a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c45b55ab4786e5a3e70ccf244425b99ad3e537c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8008cd67d05266050a0c6f9b20a6c127b90baee Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd59a3d4adcafc457854d738ceb084782f3fe91f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d35af93cb033d6592506f1f23d69cd7c19fe4a8a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6800f7db3550fbadfafcac6e2ae0c81eae53d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py @@ -0,0 +1,26 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .cache import PagedAttentionCache +from .continuous_api import ContinuousBatchingManager, ContinuousMixin +from .requests import RequestState, RequestStatus + + +__all__ = [ + "ContinuousBatchingManager", + "ContinuousMixin", + "PagedAttentionCache", + "RequestState", + "RequestStatus", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea783c242787cd9d6d16530981410368016134ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b557494ab9957eb91e63063a9f1eb25e6cfb7b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cef0cba38e1895f4f5e57fe2c80394c97de28dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79068ab6ab8412362c438c89be8b183998010c0b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e740a7629d9422997c75a001c1105ef15a14683 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f37fb0f1adb11c1ef635968d5bde89275aee8398 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc 
differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6e057be84a7f04f3638bad859af47910ba5ca9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py @@ -0,0 +1,606 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque +from math import floor, gcd, sqrt +from typing import Optional, Union + +import torch + +from ...configuration_utils import PretrainedConfig +from ...generation.configuration_utils import GenerationConfig +from ...utils.metrics import attach_tracer, traced +from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator +from .requests import get_device_and_memory_breakdown, logger + + +def group_layers_by_attn_type(config: PretrainedConfig) -> tuple[list[list[int]], list[str]]: + """ + Group layers depending on the attention mix, according to VLLM's hybrid allocator rules: + - Layers in each group need to have the same type of attention + - All groups have the same number of layers + + For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"] + We would get two groups: [0, 3] and [1, 2], [4,5], [6,7]. + """ + # If the config has no layer_type attribute, it means all layers are the same attention type + layer_types = getattr(config, "layer_types", None) + if layer_types is None: + attn_type = "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention" + layer_types = [attn_type for _ in range(config.num_hidden_layers)] + + # We then count the number of layers of each type + layer_counts = {} + for i, layer_type in enumerate(layer_types): + layer_counts[layer_type] = layer_counts.get(layer_type, []) + [i] + + # The size of all groups is the greatest common divisor of the number of layers of each type + group_size = gcd(*[len(indices) for indices in layer_counts.values()]) + + # We then group the layers by type + layer_groups = [] + for layer_type, indices in layer_counts.items(): + for i in range(0, len(indices), group_size): + layer_groups.append(indices[i : i + group_size]) + # And note the layer types + group_types = [layer_types[lg[0]] for lg in layer_groups] + return layer_groups, group_types + + +@attach_tracer() +class PagedAttentionCache: + """ + Manages the cache for a paged attention mechanism, inspired by VLLM's hybrid allocator. The cache relies on making + groups of layers to reduce the complexity of cache management and fragmentation. + + The cache uses a three-level hierarchy: + - Pages: The smallest unit of cache, a page has a size of [num_heads, head_size], which is the space needed to + store the key or value states for one token and one layer. 
For a model with only full-attention layers, to store + the KV cache of one token, we need `2 * num_layers` pages: key and values each take `num_layers` pages. + Pages are grouped into blocks: + - Blocks: A block is a collection of `block_size` pages, serving as the allocation unit to reduce management + complexity and fragmentation. Cache is allocated and freed block by block, not page by page. One block is + allocated to one layer group, which only has one attention type, like full-attention or sliding-attention. + If all layers in the model have the same attention type, then all layers will be in the same group. There is + more than one group if and only if the model has a mixed attention types, like layers with full-attention and + layers with sliding-attention. + - Cache tensors: The physical supports for the cache. There are as many cache tensors as there are layer in a + layer group, and the shape of the cache tensor is `[num_blocks * block_size, num_heads, head_size]`. + + Grouping layers into groups is useful because when we allocate one block to a group N, the block allocated is the + same for all layers in group N, equivalently it is allocated across all cache tensors. This allows us to + efficiently allocate and free blocks, and to efficiently read and write key and value states. + + For instance, imagine we have 8 blocks of cache and a model with two layer groups: a full-attention group with 3 + layers and a sliding-attention group with 3 layers. At creation time, the physical cache tensors look like this: + + cache_tensor_0: □ □ □ □ □ □ □ □ + cache_tensor_1: □ □ □ □ □ □ □ □ + cache_tensor_2: □ □ □ □ □ □ □ □ + + where □ means the blocks is not allocated to any layer group yet. We have 3 cache tensors because there are + 3 layers per group. + We allocate 1 block to each group, after allocation, the cache tensors look like this: + + cache_tensor_0: ✖ ◉ □ □ □ □ □ □ + cache_tensor_1: ✖ ◉ □ □ □ □ □ □ + cache_tensor_2: ✖ ◉ □ □ □ □ □ □ + + where ✖ means the block is allocated to the full-attention group, and ◉ means the block is allocated to the + sliding-attention group. + Now, if we continue to generate, and the sliding window has been reached, we only need to allocate a new block + for the full-attention group, and the cache tensors look like this: + + cache_tensor_0: ✖ ◉ ✖ □ □ □ □ □ + cache_tensor_1: ✖ ◉ ✖ □ □ □ □ □ + cache_tensor_2: ✖ ◉ ✖ □ □ □ □ □ + + And after further generation, when we need a new block allocated: + + cache_tensor_0: ✖ ◉ ✖ ✖ □ □ □ □ + cache_tensor_1: ✖ ◉ ✖ ✖ □ □ □ □ + cache_tensor_2: ✖ ◉ ✖ ✖ □ □ □ □ + + This would not have been possible if all layers were in the same group: we would have had to allocate a new block + for the sliding-attention group, although it is not needed. + """ + + # TODO: this init is quite long, maybe a refactor is in order + def __init__( + self, + config: PretrainedConfig, + generation_config: GenerationConfig, + device: torch.device, + dtype: torch.dtype = torch.float16, + layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, + tp_size: Optional[int] = None, + ) -> None: + """Initialize a paged attention cache for efficient memory usage. 
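The grouping rule described above can be checked with a small self-contained sketch (a simplified stand-in for `group_layers_by_attn_type`, not the diffed function itself; the toy `layer_types` list is the 8-layer mix from its docstring):

from math import gcd

def group_layers(layer_types: list[str]) -> tuple[list[list[int]], list[str]]:
    # Collect the indices of the layers of each attention type
    by_type: dict[str, list[int]] = {}
    for i, layer_type in enumerate(layer_types):
        by_type.setdefault(layer_type, []).append(i)
    # Every group has the same size: the gcd of the per-type layer counts
    group_size = gcd(*(len(indices) for indices in by_type.values()))
    groups, group_types = [], []
    for layer_type, indices in by_type.items():
        for start in range(0, len(indices), group_size):
            groups.append(indices[start : start + group_size])
            group_types.append(layer_type)
    return groups, group_types

# 2 sliding + 6 full layers -> group size gcd(2, 6) = 2, four groups in total:
# ([[0, 3], [1, 2], [4, 5], [6, 7]], ['sliding', 'full', 'full', 'full'])
print(group_layers(["sliding", "full", "full", "sliding", "full", "full", "full", "full"]))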
+ + Args: + config: Model configuration + generation_config: Generation configuration containing cache parameters + device: Device for the cache tensors + dtype: Data type of the cache + layer_device_map: Optional mapping of layer indices to devices + tp_size: Tensor parallelism size + """ + self.config = config + self.dtype = dtype + self.device = device + + # Extract model dimensions + kv_heads = getattr(config, "num_key_value_heads", None) + self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads + head_dim = getattr(config, "head_dim", None) + self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads + + # Extract cache dimensions + self.block_size = getattr(generation_config, "block_size", 32) + + # Group layers depending on the attention mix + layer_groups, group_types = group_layers_by_attn_type(config) + group_size = len(layer_groups[0]) + self.num_groups = len(layer_groups) + + self.sliding_windows = {} + self.layer_index_to_group_indices = {} + for i, group in enumerate(layer_groups): + sliding_window = config.sliding_window if group_types[i] == "sliding_attention" else 1 + for j, layer in enumerate(group): + self.layer_index_to_group_indices[layer] = (i, j) + self.sliding_windows[layer] = sliding_window + + # Handle TP (or dont) + if tp_size is not None and tp_size > 1: + if self.num_key_value_heads % tp_size != 0: + raise ValueError( + f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." + ) + # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. + # self.num_key_value_heads //= tp_size # TODO: why is this commented out? + + # Infer number of blocks and max batch tokens + page_size = self.head_dim * self.num_key_value_heads + + if getattr(config, "attn_implementation", None) == "paged_attention": + num_attention_masks = 0 + else: + # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))` + num_attention_masks = 2 if "sliding_attention" in group_types else 1 + + memory_handler = PagedAttentionMemoryHandler( + block_size=self.block_size, + page_size=page_size, + num_groups=self.num_groups, + group_size=group_size, + peak_activation_per_token=(config.hidden_size + config.vocab_size), + num_attention_masks=num_attention_masks, + ) + num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( + num_blocks=getattr(generation_config, "num_blocks", None), + max_batch_tokens=getattr(generation_config, "max_batch_tokens", None), + max_memory_percent=getattr(generation_config, "max_memory", 0.9), + cache_dtype=self.dtype, + ) + + # Add the inferred attributes to the class + self.num_blocks = num_blocks + self.max_batch_tokens = max_batch_tokens + logger.info( + f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, " + f"{self.max_batch_tokens = } {num_attention_masks = }" + ) + + # Initialize the cache + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + # We add one extra token to the cache to handle padding and generally discard unwanted tokens + self.cache_shape = (num_blocks * self.block_size + 1, self.num_key_value_heads, self.head_dim) + for _ in range(group_size): + new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device) + new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, 
device=self.device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }") + + # Block management data structures + self._free_blocks = deque(range(num_blocks)) + self.group_cache_managers: list[CacheAllocator] = [] + for i, group_type in enumerate(group_types): + if group_type == "full_attention": + cm = FullAttentionCacheAllocator(i, self.block_size) + elif group_type == "sliding_attention": + cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window) + else: + raise ValueError(f"Invalid group type: {group_type}") + self.group_cache_managers.append(cm) + + @traced + def allocate_blocks(self, n_blocks: int, request_id: str) -> int: + """Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache + managers, and this method only returns the maximum number of blocks actually allocated across all managers.""" + max_allocated = 0 + for cm in self.group_cache_managers: + allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks) + if allocated is None: + return None + max_allocated = max(max_allocated, allocated) + return max_allocated + + @traced + def free_blocks(self, request_id: str) -> None: + """Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done + by the cache managers.""" + for cm in self.group_cache_managers: + cm.free_blocks(request_id, self._free_blocks) + + def get_num_free_blocks(self) -> int: + """Get the current number of unallocated blocks available for new requests.""" + return len(self._free_blocks) + + @traced + def extend_read_indices( + self, request_id: str, past_length: int, query_length: int, read_index: list[list[int]] + ) -> None: + """Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method + coordinates with all cache managers to build the complete set of read indices needed for attention computation. + """ + for cm, read_indices in zip(self.group_cache_managers, read_index): + indices = cm.get_read_indices(request_id, past_length, query_length) + read_indices.extend(indices) + + @traced + def extend_write_indices( + self, request_id: str, past_length: int, query_length: int, write_index: list[list[int]] + ) -> None: + """Retrieve physical cache indices for writing new KV states to the cache across all layer groups. This method + coordinates with all cache managers to build the complete set of write indices needed to store computed KV + states.""" + for cm, write_indices in zip(self.group_cache_managers, write_index): + indices = cm.get_write_indices(request_id, past_length, query_length) + write_indices.extend(indices) + + @traced + def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> dict[str, int]: + """Retrieve the key sequence length for the given request_id across all layer types. 
Returns a dictionary of + layer types to their corresponding key sequence lengths.""" + seqlens_k = {} + for cm in self.group_cache_managers: + attn_type, seqlen_k = cm.get_seqlens_k(request_id, past_length, query_length) + seqlens_k[attn_type] = seqlen_k + return seqlens_k + + @traced + def update( + self, + key_states: torch.Tensor, # shape [1, num_kv_heads, seqlen_kv, head_dim] + value_states: torch.Tensor, # shape [1, num_kv_heads, seqlen_kv, head_dim] + layer_idx: int, + read_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_kv + past_length] + write_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_q] + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: # shape [seqlen_kv + past_length, num_kv_heads, head_dim] + """Update the cache with new key-value states for a specific layer. This method writes new KV states to the + appropriate cache locations. The behavior differs based on the layer's attention type: + + - Full attention: New KV states are written to cache, then complete sequence is read from cache + - Sliding window: Old KV is read from cache along with extra spaces for the new KV, then new KV is written to + cache. This is because new KV might overwrite the old KV, so we need to read the old KV first. + + Returns the complete KV states (cached + new) for attention computation. + """ + # Retrieve the layer read and write indices, and if there is a sliding window + group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx] + layer_read_index = read_index[group_idx] + layer_write_index = write_index[group_idx] + # Select the correct cache + k_cache = self.key_cache[layer_idx_in_group] + v_cache = self.value_cache[layer_idx_in_group] + # Transpose the key and value states to match the cache shape, after which shape is [seqlen_kv, num_kv_heads, head_dim] + key_states = key_states.transpose(1, 2).squeeze(0) + value_states = value_states.transpose(1, 2).squeeze(0) + + # Case: full attention + sliding_window = self.sliding_windows[layer_idx] + if sliding_window == 1: + k_cache[layer_write_index, :, :] = key_states + v_cache[layer_write_index, :, :] = value_states + key_states_with_cache = k_cache[layer_read_index, :, :] + value_states_with_cache = v_cache[layer_read_index, :, :] + + # Case: sliding window -- we need to be careful of read/write order because of chunked prefill, because it's + # the only case where you may write over cache you need to use + else: + # Add the cache to the key and value states + mask = layer_read_index == -1 # TODO: can this can be efficiently precomputed? + key_states_with_cache = k_cache[layer_read_index, :, :] + key_states_with_cache[mask] = key_states + value_states_with_cache = v_cache[layer_read_index, :, :] + value_states_with_cache[mask] = value_states + # Write new KV values to the cache + k_cache[layer_write_index, :, :] = key_states + v_cache[layer_write_index, :, :] = value_states + + # Return the new KV values + return key_states_with_cache, value_states_with_cache + + +# TODO: rework computation with the groups and their sizes +class PagedAttentionMemoryHandler: + """A helper class to determine the best number of pages and maximum number of tokens per batch for the paged + attention cache, providing automatic sizing based on available GPU memory. 
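The read-before-write order that `update` uses for sliding-window layers can be illustrated on toy tensors (values and shapes are made up; the real cache stores [num_kv_heads, head_dim] pages, not scalars):

import torch

cache = torch.tensor([10.0, 11.0, 12.0, 13.0])   # four physical pages of a toy 1-D cache
read_index = torch.tensor([0, 1, 2, -1, -1])     # pages still holding old window KV, then -1 slots for the new tokens
write_index = torch.tensor([0, 1])               # the new tokens wrap back onto pages 0 and 1
new_kv = torch.tensor([20.0, 21.0])

keys = cache[read_index]                         # 1) gather the old KV first (-1 slots hold placeholders)
keys[read_index == -1] = new_kv                  # 2) splice the new KV into the -1 slots
cache[write_index] = new_kv                      # 3) only now overwrite the old pages
print(keys)                                      # tensor([10., 11., 12., 20., 21.])
print(cache)                                     # tensor([20., 21., 12., 13.])

Writing before reading would have clobbered pages 0 and 1 while they were still part of the attention window, which is exactly the hazard the docstring warns about.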
+ The helper works using the number of pages, which is tied to the number of blocks by: + num_blocks = num_pages // block_size + + The memory footprint consists of three main components: + - Cache memory: the space needed to store the cache tensors: + 2 * layer_group_size * [num_pages, page_size] * cache_dtype + - Activation memory: the space temporarily taken by the largest activation during the model forward pass: + peak_activation_per_token * max_tokens_per_batch * activation_dtype_size + - Static tensors: the space taken by the input/output buffers and metadata tensors for batch processing, sum of: + - inputs_ids + outputs_ids + position_ids + logits_indices: 4 * max_tokens_per_batch * int32_size + - attention_mask: num_attention_masks * num_pages * max_tokens_per_batch * activation_dtype_size + - cumulative_seqlens_q + cumulative_seqlens_k: (1 + 2) * max_tokens_per_batch * int32_size + - write_index_tensor: num_groups * max_tokens_per_batch * int32_size + - read_index_tensor: num_groups * (num_pages + max_tokens_per_batch) * int32_size + + The handler can operate in three modes: + 1. Auto-sizing: Determines both number of pages and maximum number of tokens per batch using quadratic optimization + 2. Fixed cache: Calculates max batch tokens given a fixed number of pages + 3. Fixed batch: Calculates number of pages given a fixed maximum batch size + + """ + + _activation_dtype = torch.bfloat16 + _input_dtype = torch.int32 + _upper_bound_max_batch_tokens = 256 + _upper_bound_num_blocks = 4096 + + def __init__( + self, + block_size: int, + page_size: int, + num_groups: int, + group_size: int, + peak_activation_per_token: int, + num_attention_masks: int, + ) -> None: + """Initialize the memory handler with the parameters that cannot be automatically inferred. + + Args: + block_size: Size of the cache blocks + page_size: Size of the cache pages + num_groups: Number of layer groups + group_size: Number of layers per layer group + peak_activation_per_token: Maximum size of activation tensor per token, = hidden_size + vocab_size + num_attention_masks: Number of attention masks, 0 if no attention mask is used, 2 if hybrid model, else 1 + """ + self.block_size = block_size + self.page_size = page_size + self.num_groups = num_groups + self.group_size = group_size + self.peak_activation_per_token = peak_activation_per_token + self.num_attention_masks = num_attention_masks + + @staticmethod + def get_available_memory(max_memory_percent: float = 1.0) -> int: + """Calculate available GPU memory for cache allocation, accounting for already allocated tensors. + This method queries the current memory state and applies the specified percentage limit to determine + how much memory can be safely used for the paged attention cache. + + Args: + max_memory_percent: Fraction of available memory to use (0.0-1.0). 1.0 means use all available memory. + + Returns: + int: Available memory in bytes for cache allocation + """ + _, total, reserved, allocated = get_device_and_memory_breakdown() + available_memory = total - max(allocated, reserved) + available_memory = int(available_memory * max_memory_percent) + return available_memory + + def infer_num_blocks_and_max_batch_tokens( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int]: + """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and + constraints. 
Check the class docstring for more details. Naming the number of pages as N and the maximum number + of tokens per batch as M, the equation solved is: + + available_memory = sum([ + MN * num_attention_masks * activation_dtype_size, + 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group), + M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group), + ]) + + where we already simplified int32_size = 4. + """ + # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial + if num_blocks is None and max_batch_tokens is None: + num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens( + max_memory_percent, cache_dtype + ) + # If only num_blocks is provided, we infer the max_batch_tokens + elif num_blocks is not None and max_batch_tokens is None: + max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype) + # If only max_batch_tokens is provided, we infer the num_blocks + elif max_batch_tokens is not None and num_blocks is None: + num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype) + + # We check if the memory footprint is too large in all cases + available_memory = self.get_available_memory(max_memory_percent) + memory_footprint = self.compute_memory_footprint( + max_batch_tokens=max_batch_tokens, + num_blocks=num_blocks, + cache_dtype=cache_dtype, + ) + if memory_footprint > available_memory: + raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}") + return num_blocks, max_batch_tokens + + def compute_num_blocks_and_max_batch_tokens( + self, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + m: float = 0.01, + ) -> tuple[int, int]: + """Calculate optimal number of blocks and maximum number of tokens per batch using quadratic optimization when + neither is fixed. This method assumes a relationship M = m * N where m is a small ratio below 1 and solves the + resulting quadratic equation to find the optimal N that maximizes utilization within memory constraints. m is + the amount of cache we can fill with one batch: m=0.01 means a batch fills at most 1% of the cache. 
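Plugging made-up numbers into that equation shows how the auto-sizing mode picks N (and M = m * N); the constants below (hidden size 4096, vocabulary 32000, a 3-layer group, 8 GiB free) are illustrative only and not taken from the diff:

from math import floor, sqrt

m = 0.01                                  # one batch fills at most 1% of the cache
num_attention_masks = 1
activation_bytes = 2                      # bfloat16
cache_bytes_per_page = 3 * 4096 * 2       # group_size * page_size * cache dtype size (made-up model)
peak_activation_per_token = 4096 + 32000  # hidden_size + vocab_size (made-up model)
num_groups = 2
available_memory = 8 * 1024**3            # pretend 8 GiB are left for the cache

# available_memory = a*N^2 + b*N + c with M = m*N, as in the equation above
a = m * num_attention_masks * activation_bytes
b = 2 * (cache_bytes_per_page + 2 * num_groups)
b += m * (peak_activation_per_token * activation_bytes + 28 + 4 * num_groups)
c = -available_memory
num_pages = floor((-b + sqrt(b * b - 4 * a * c)) / (2 * a))
print(num_pages, num_pages // 32, floor(m * num_pages))  # pages, blocks (block_size=32), max batch tokens

The library additionally clamps both results to fixed upper bounds (`_upper_bound_num_blocks`, `_upper_bound_max_batch_tokens`), which this sketch omits.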
The equation + to solve is: + + available_memory = sum([ + m * N^2 * num_attention_masks * activation_dtype_size, + 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group), + m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group), + ]) + """ + cache_memory = self.get_available_memory(max_memory_percent) + logger.info(f"Cache memory: {cache_memory}") + + # Compute second-degree polynomial coefficients + a = m * self.num_attention_masks * self._activation_dtype.itemsize + b = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups) + b += m * (self.peak_activation_per_token * self._activation_dtype.itemsize + 28 + 4 * self.num_groups) + c = -cache_memory + logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }") + + # Compute discriminant and greatest solution + discriminant = b**2 - 4 * a * c + if discriminant < 0: + raise ValueError(f"Discriminant is negative: {discriminant = }") + greatest_solution = (-b + sqrt(discriminant)) / (2 * a) + if greatest_solution < 0: + raise ValueError(f"Greatest solution is negative: {greatest_solution = }") + + # Infer number of blocks and max batch tokens + num_pages = floor(greatest_solution) + num_blocks = num_pages // self.block_size + if num_blocks > self._upper_bound_num_blocks: + logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }") + num_blocks = self._upper_bound_num_blocks + max_batch_tokens = int(greatest_solution * m) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }") + max_batch_tokens = self._upper_bound_max_batch_tokens + return num_blocks, max_batch_tokens + + def compute_max_batch_tokens( + self, + num_blocks: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by: + + M = (available_memory - 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group)) + / (activation_dtype_size * (N * num_attention_masks + peak_activation_per_token) + 28 + 4 * num_group) + """ + cache_memory = self.get_available_memory(max_memory_percent) + num_pages = num_blocks * self.block_size + # Compute numerator + num = cache_memory + num -= 2 * num_pages * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups) + # Compute denominator + denum = self._activation_dtype.itemsize * ( + num_pages * self.num_attention_masks + self.peak_activation_per_token + ) + denum += 28 + 4 * self.num_groups + # Compute max batch tokens and return + max_batch_tokens = floor(num / denum) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }") + max_batch_tokens = self._upper_bound_max_batch_tokens + return max_batch_tokens + + def compute_num_blocks( + self, + max_batch_tokens: int, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> int: + """Calculate number of cache blocks N given a fixed maximum token per token M. 
The formula for N is given by: + + N = (available_memory - M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group)) + / (2 * (layer_group_size * page_size * cache_dtype + 2 * num_group) + M * (num_attention_masks * activation_dtype_size)) + """ + cache_memory = self.get_available_memory(max_memory_percent) + # Compute numerator + num = cache_memory + num -= max_batch_tokens * self.peak_activation_per_token * self._activation_dtype.itemsize + num -= max_batch_tokens * (28 + 4 * self.num_groups) + # Compute denominator + denum = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups) + denum += max_batch_tokens * (self.num_attention_masks * self._activation_dtype.itemsize) + denum += max_batch_tokens * self._activation_dtype.itemsize + # Compute cache size and return number of blocks + num_pages = floor(num / denum) + num_blocks = num_pages // self.block_size + if num_blocks > self._upper_bound_num_blocks: + logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }") + num_blocks = self._upper_bound_num_blocks + return num_blocks + + def compute_memory_footprint( + self, + num_blocks: Optional[int] = None, + max_batch_tokens: Optional[int] = None, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int, int]: + """Calculate the memory footprint breakdown for a given number of blocks and maximum batch tokens. The memory + footprint is given by: + + available_memory = sum([ + MN * num_attention_masks * activation_dtype_size, + 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group), + M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group), + ]) + but is broken down below. + """ + num_pages = num_blocks * self.block_size + + cache_memory_footprint = 2 * self.group_size * num_pages * self.page_size * cache_dtype.itemsize + + activation_memory_footprint = self.peak_activation_per_token * self._activation_dtype.itemsize + activation_memory_footprint *= max_batch_tokens + + inputs_outputs_positions_and_logits_memory_footprint = 4 * max_batch_tokens * 4 # second 4 is for int32 size + + attention_memory_footprint = self.num_attention_masks * self._activation_dtype.itemsize + attention_memory_footprint *= num_pages * max_batch_tokens + + cumulative_seqlens_memory_footprint = 3 * max_batch_tokens * 4 # 4 is for int32 size + + write_index_memory_footprint = self.num_groups * max_batch_tokens * 4 # 4 is for int32 size + read_index_memory_footprint = self.num_groups * (num_pages + max_batch_tokens) * 4 # 4 is for int32 size + + total_memory_footprint = sum( + [ + cache_memory_footprint, + activation_memory_footprint, + inputs_outputs_positions_and_logits_memory_footprint, + attention_memory_footprint, + cumulative_seqlens_memory_footprint, + write_index_memory_footprint, + read_index_memory_footprint, + ] + ) + return total_memory_footprint diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..7e2d4f2b553278f3dcbf3b7d4a46b6a0f2f9f0dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from collections import deque +from math import ceil +from typing import Optional + +from .requests import logger + + +class CacheAllocator(ABC): + """Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine + when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache.""" + + _index: int + _block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request + + @abstractmethod + def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]: + """Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None + otherwise.""" + pass + + def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None: + """Frees all blocks associated with a request_id.""" + if request_id in self._block_table: + blocks_to_free = self._block_table.pop(request_id) + free_blocks.extend(blocks_to_free) + else: + logger.warning( + f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}" + ) + + @abstractmethod + def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]: + """Returns the physical indices of where to read request_id's cache in the cache tensor.""" + pass + + @abstractmethod + def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]: + """Returns the physical indices of where to write request_id's cache in the cache tensor.""" + pass + + @abstractmethod + def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]: + """Returns the attention type of the cache allocator and the key sequence length for the given request_id.""" + pass + + +class FullAttentionCacheAllocator(CacheAllocator): + """Cache manager for a group of full attention layers.""" + + def __init__(self, index: int, block_size: int) -> None: + """Initializes the cache manager for a group of full attention layers. + Args: + - index: the index of the associated layer group + - block_size: the size of the blocks in the cache + """ + self._index = index + self.block_size = block_size + self._block_table = {} + + def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]: + """Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None + otherwise. 
For group of full attention layers, we always allocate the number of requested blocks.""" + if len(free_blocks) < n_blocks: + return None + if request_id not in self._block_table: + self._block_table[request_id] = [] + self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks)) + return n_blocks + + def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]: + """Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we + first write the new cache to the cache tensor and then read the entire cache from the beginning to the end.""" + # Retrieve the block table for the request and raise an error if it doesn't exist + block_table = self._block_table.get(request_id) + if block_table is None: + raise ValueError(f"No block table found for request {request_id}") + # Compute the physical indices + physical_indices = [] + for i in range(past_length + query_length): + block_idx = i // self.block_size + block_offset = i % self.block_size + physical_index = block_table[block_idx] * self.block_size + block_offset + physical_indices.append(physical_index) + return physical_indices + + def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]: + """Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new + cache as a continuation of the existing cache for the same request.""" + block_table = self._block_table.get(request_id) + if block_table is None: + raise ValueError(f"No block table found for request {request_id}") + # Compute the physical indices + physical_indices = [] + for i in range(past_length, past_length + query_length): + block_idx = i // self.block_size + block_offset = i % self.block_size + physical_index = block_table[block_idx] * self.block_size + block_offset + physical_indices.append(physical_index) + return physical_indices + + def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]: + """Returns the attention type of the cache allocator and the key sequence length for the given request_id.""" + seqlens_k = past_length + query_length + return "full_attention", seqlens_k + + +class SlidingAttentionCacheAllocator(CacheAllocator): + """Cache manager for sliding window attention layers.""" + + def __init__(self, index: int, block_size: int, sliding_window: int) -> None: + """Initializes the cache manager for a group of sliding window attention layers. + Args: + - index: the index of the associated layer group + - block_size: the size of the blocks in the cache + - sliding_window: the size of the sliding window + """ + self._index = index + self.block_size = block_size + self.sliding_window = sliding_window + self._max_blocks_per_request = ceil(self.sliding_window / self.block_size) + self._block_table = {} + + def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]: + """Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None + otherwise. 
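The logical-to-physical translation used by `get_read_indices` and `get_write_indices` above reduces to one line of arithmetic; here it is with a made-up block table (the request owns physical blocks 5 and 2, block_size = 4):

block_size = 4
block_table = [5, 2]          # hypothetical blocks allocated to one request, in logical order

def to_physical(i: int) -> int:
    # logical token position i -> physical row in the cache tensor
    return block_table[i // block_size] * block_size + i % block_size

# Reading 6 cached tokens walks through block 5 and then block 2:
print([to_physical(i) for i in range(6)])      # [20, 21, 22, 23, 8, 9]
# Writing 2 new tokens with past_length = 6 continues where the read left off:
print([to_physical(i) for i in range(6, 8)])   # [10, 11]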
For group of sliding window attention layers, we only allocate up to the point where we can fit an + entire sliding window in the cache tensor.""" + if request_id not in self._block_table: + self._block_table[request_id] = [] + # Early return if we are already at the max number of blocks per request + already_allocated = len(self._block_table[request_id]) + if already_allocated == self._max_blocks_per_request: + return 0 + # Compute actual number of blocks to allocate + after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request) + actual_n_blocks = after_allocation - already_allocated + # Classic allocation + if len(free_blocks) < actual_n_blocks: + return None + self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks)) + return actual_n_blocks + + def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]: + """Returns the physical indices of where to read request_id's cache in the cache tensor. + For a group of sliding window attention layers, we read from the cache tensor before writing on it, because the + new cache can overwrite the old one. To form the cache + new key / values states, we read the at most + sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices + which indicate where to store the new key or values indices.""" + # Retrieve the block table for the request and raise an error if it doesn't exist + block_table = self._block_table.get(request_id) + if block_table is None: + raise ValueError(f"No block table found for request {request_id}") + # Apply sliding window + start_index = 0 if past_length < self.sliding_window else past_length % self.sliding_window + cache_length = min(past_length, self.sliding_window - 1) + # Compute the physical indices + physical_indices = [] + for i in range(start_index, start_index + cache_length): + i %= self.sliding_window + block_idx = i // self.block_size + block_offset = i % self.block_size + physical_index = block_table[block_idx] * self.block_size + block_offset + physical_indices.append(physical_index) + return physical_indices + [-1] * query_length + + def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]: + """Returns the physical indices of where to write request_id's cache in the cache tensor. 
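The wrap-around arithmetic of the sliding allocator, including the rolling-buffer write behaviour described just below, can be sketched the same way (sliding_window = 4, block_size = 4 and a single made-up physical block 7; `write_positions` is an illustrative helper, not the diffed method):

sliding_window, block_size, block_table = 4, 4, [7]

def write_positions(past_length: int, query_length: int) -> list[int]:
    start = past_length % sliding_window
    kept = min(query_length, sliding_window)      # tokens beyond a full window are discarded
    positions = []
    for i in range(start, start + kept):
        i %= sliding_window                       # wrap around the rolling buffer
        positions.append(block_table[i // block_size] * block_size + i % block_size)
    return [-1] * (query_length - kept) + positions

print(write_positions(past_length=3, query_length=2))   # [31, 28]: the second token wraps to the window start
print(write_positions(past_length=0, query_length=6))   # [-1, -1, 28, 29, 30, 31]: the first two tokens are dropped via -1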
For a group of + sliding window attention layers, we write the new cache in rolling-buffer kind of way: if we reach the end of + the allocated physical cache, we start writing from the beginning of the physical cache again.""" + # Retrieve the block table for the request and raise an error if it doesn't exist + block_table = self._block_table.get(request_id) + if block_table is None: + raise ValueError(f"No block table found for request {request_id}") + # Apply sliding window + start_index = past_length % self.sliding_window + cache_length = min(query_length, self.sliding_window) + padding_length = query_length - cache_length + # Compute the physical indices + physical_indices = [] + for i in range(start_index, start_index + cache_length): + i %= self.sliding_window + block_idx = i // self.block_size + block_offset = i % self.block_size + physical_index = block_table[block_idx] * self.block_size + block_offset + physical_indices.append(physical_index) + if padding_length > 0: + physical_indices = [-1] * padding_length + physical_indices + return physical_indices + + def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]: + """Returns the attention type of the cache allocator and the key sequence length for the given request_id.""" + seqlens_k = query_length + min(past_length, self.sliding_window - 1) + return "sliding_attention", seqlens_k + + +# TODO: test the impact of this +# def get_read_indices(self, request_id: str, past_length: int) -> list[int]: +# # Retrieve the block table for the request and raise an error if it doesn't exist +# block_table = self._block_table.get(request_id) +# if block_table is None: +# raise ValueError(f"No block table found for request {request_id}") +# # Compute the physical indices +# physical_indices = [] +# n_left = past_length +# for block_idx in block_table: +# block_physical_index = block_idx * self.block_size +# pages_used = min(self.block_size, n_left) +# physical_indices.extend(block_physical_index + i for i in range(pages_used)) +# n_left -= pages_used +# if n_left == 0: +# return physical_indices +# raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1801fa163e137f69e128cebbaa36877eaaa28a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py @@ -0,0 +1,1047 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import queue +import threading +from dataclasses import dataclass +from functools import partial +from itertools import count +from time import perf_counter +from typing import Optional, Union + +import torch +from torch import nn +from tqdm import tqdm + +from ...configuration_utils import PretrainedConfig +from ...generation.configuration_utils import GenerationConfig +from ...utils.logging import logging +from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced +from .cache import PagedAttentionCache +from .requests import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger +from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler + + +def build_attention_mask( + attention_mask: torch.Tensor, + cumulative_seqlens_q: torch.Tensor, + cumulative_seqlens_k: torch.Tensor, + sliding_window: int = 1, +) -> None: + """Builds an attention mask inplace using the cumulative seqlens of the query and key. If given a sliding window, it + will also apply a sliding window mask on top. The attention mask is not boolean, it uses zeroes and -inf (or its + equivalent) so it's more of an attention score bias tensor. + The attention mask is a block-diagonal matrix, with each block an attention mask for a single query-key pair. + Each of those block is built from a causal mask and, if there is a sliding window, a sliding window mask. + + An example is represented below, with seqlen_k = 8, seqlen_q = 4 and sliding_window = 6: + + CAUSAL MASK: + + █ █ █ █ █ ░ ░ ░ + █ █ █ █ █ █ ░ ░ + █ █ █ █ █ █ █ ░ + █ █ █ █ █ █ █ █ + + SLIDING WINDOW MASK: + ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 8 - 4 - 6 = -2 offset to the right + <─┴─> + ░ █ | █ █ █ █ █ █ █ █ + ░ ░ | █ █ █ █ █ █ █ █ + ░ ░ | ░ █ █ █ █ █ █ █ + ░ ░ | ░ ░ █ █ █ █ █ █ + + ATTENTION MASK (sum of causal and sliding window masks): + + █ █ █ █ █ ░ ░ ░ + █ █ █ █ █ █ ░ ░ + ░ █ █ █ █ █ █ ░ + ░ ░ █ █ █ █ █ █ + + Another example with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2: + + CAUSAL MASK: + + █ █ █ ░ ░ + █ █ █ █ ░ + █ █ █ █ █ + + SLIDING WINDOW MASK: + ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 5 - 3 - 2 = 0 offset to the right + <┴> + | ░ █ █ █ █ + | ░ ░ █ █ █ + | ░ ░ ░ █ █ + + ATTENTION MASK (sum of causal and sliding window masks): + + ░ █ █ ░ ░ + ░ ░ █ █ ░ + ░ ░ ░ █ █ + + """ + min_value = torch.finfo(attention_mask.dtype).min + for i in range(len(cumulative_seqlens_q) - 1): + seqlen_q = cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] + seqlen_k = cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i] + if seqlen_q < seqlen_k and seqlen_q >= 1: + causal_diagonal = seqlen_k - seqlen_q + 1 + else: + causal_diagonal = 1 + query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1]) + key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1]) + # Apply causal mask + minus_inf = torch.full( + attention_mask[..., query_range, key_range].shape, + min_value, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + masked = torch.triu(minus_inf, diagonal=causal_diagonal) + # Apply sliding window mask if needed + if sliding_window > 1: + sliding_diagonal = seqlen_k - seqlen_q - sliding_window + masked += torch.tril(minus_inf, diagonal=sliding_diagonal) + # Replace in attention mask + attention_mask[..., query_range, key_range] = masked + + +@dataclass +class PagedAttentionArgs: + input_ids: torch.Tensor + attention_mask: Optional[torch.Tensor] + position_ids: torch.Tensor + cumulative_seqlens_q: 
torch.Tensor + cumulative_seqlens_k: torch.Tensor + max_seqlen_q: int + max_seqlen_k: int + write_index: list[torch.Tensor] + read_index: list[torch.Tensor] + logits_indices: torch.Tensor + cache: PagedAttentionCache + use_cache: bool = False + + +# Continuous Batch Processor (Internal Logic) +@attach_tracer() +class ContinuousBatchProcessor: + def __init__( + self, + cache: PagedAttentionCache, + config: PretrainedConfig, + generation_config: GenerationConfig, + input_queue: queue.Queue, + output_queue: queue.Queue, + stop_event: threading.Event, + model_device: torch.device, + model_dtype: torch.dtype, + scheduler: Scheduler, + streaming: bool = False, + manual_eviction: bool = False, + slice_inputs: bool = True, # TODO: There should be an heuristic to decide on slicing, compile, cuda graphs... + ) -> None: + """Initialize the continuous batch processor. + + Args: + cache: A [`PagedAttentionCache`] object + config: The model configuration + generation_config: The generation configuration + input_queue: Queue for incoming requests + output_queue: Queue for outgoing results + stop_event: Event to signal processing should stop + model_device: Device for model inputs/outputs + model_dtype: Data type for model inputs/outputs + scheduler: The [`Scheduler`] to use + streaming: Whether to stream tokens as they're generated + manual_eviction: Whether to manually evict blocks from the cache + slice_inputs: Whether to slice the inputs to the model + """ + self.cache = cache + self.config = config + self.generation_config = generation_config + self.input_queue = input_queue + self.output_queue = output_queue + self.stop_event = stop_event + self.model_device = model_device + self.model_dtype = model_dtype + self.scheduler = scheduler + self.streaming = streaming + self.manual_eviction = manual_eviction + self.slice_inputs = slice_inputs + + # Retrieve the size of the sliding window if there is one + self.sliding_window = 1 if getattr(config, "sliding_window", None) is None else config.sliding_window + + self.requests_in_batch: list[RequestState] = [] + + # Set up metrics collector + self.max_batch_tokens = cache.max_batch_tokens + self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens) + + # Setup static tensors + self.total_query_length = 0 + self.total_key_length = 0 + self.total_batch_size = 0 + self.setup_static_tensors(cache.num_groups) + + @traced(standalone=True) + def setup_static_tensors(self, num_groups: int) -> None: + T = self.max_batch_tokens + num_pages = self.cache.num_blocks * self.cache.block_size + self.tensor_metadata = {"dtype": torch.int32, "device": self.model_device} + + # Some tensors always have the same shape regardless of the model + self.input_ids = torch.empty((1, T), **self.tensor_metadata) + self.position_ids = torch.empty((1, T), **self.tensor_metadata) + self.cumulative_seqlens_q = torch.empty((T + 1,), **self.tensor_metadata) + self.max_seqlen_q = 0 + self.logits_indices = torch.empty((T,), **self.tensor_metadata) + self.output_ids = torch.empty((1, T), **self.tensor_metadata) + + # For some kwargs, we have a dict of tensors with as many items as there are attention types + layer_types = getattr(self.config, "layer_types", None) + if layer_types is None: + sliding_window = getattr(self.config, "sliding_window", 1) + layer_types = ["full_attention"] if sliding_window in [1, None] else ["sliding_attention"] + layer_types = list(set(layer_types)) + + self.cumulative_seqlens_k = { + layer_type: torch.empty((T + 1), **self.tensor_metadata) for layer_type 
in layer_types + } + self.max_seqlen_k = dict.fromkeys(layer_types, 0) + + if self.return_attention_mask(): + attn_mask_kwargs = { + "size": (1, 1, T, num_pages + T), + "dtype": self.model_dtype, + "device": self.model_device, + } + self.attention_mask = {layer_type: torch.empty(**attn_mask_kwargs) for layer_type in layer_types} + else: + self.attention_mask = None + + # For other kwargs, we need a list of tensors with as many tensors as there are groups + self.write_index_storage = [torch.empty((T,), **self.tensor_metadata) for _ in range(num_groups)] + self.read_index_storage = [torch.empty((num_pages + T), **self.tensor_metadata) for _ in range(num_groups)] + # For read index, the +T is because there are -1 for seqlen_q when model uses a sliding window + + # After allocating empty tensors, we reset them to the right value + self.reset_static_tensors(full_reset=True) + + def return_attention_mask(self) -> bool: + return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call + + @traced + @torch.no_grad() + def reset_static_tensors(self, full_reset: bool = False): + """Reset static tensors for the next batch. In between batches, reset only the parts that were used in the last + batch, but for initialisation, we can reset everything using the (full_reset) flag.""" + # Compute the slice to reset + if full_reset or not self.slice_inputs: + q_len = self.write_index_storage[0].size(-1) + k_len = self.read_index_storage[0].size(-1) + b_size = self.write_index_storage[0].size(0) + else: + q_len = self.total_query_length + k_len = self.total_key_length + b_size = self.total_batch_size + + # Reset the attributes that always have the same shape + self.input_ids[:, :q_len].zero_() + self.position_ids[:, :q_len].zero_() + self.cumulative_seqlens_q[: b_size + 1].zero_() + self.max_seqlen_q = 0 + self.logits_indices[:q_len].fill_(-1) + self.output_ids[:, :q_len].fill_(-1) + + # Reset the attributes that are either tensors or dict of tensors + for layer_type in self.cumulative_seqlens_k: + self.cumulative_seqlens_k[layer_type][: b_size + 1].zero_() + self.max_seqlen_k[layer_type] = 0 + if self.attention_mask is not None: + self.attention_mask[layer_type][:, :, :q_len, :k_len].fill_(torch.finfo(self.model_dtype).min) + + # Reset the attributes that are lists of tensors + for i in range(self.cache.num_groups): + self.write_index_storage[i][:q_len].fill_(-1) + self.read_index_storage[i][: q_len + k_len].fill_(-1) + + def get_model_kwargs(self) -> PagedAttentionArgs: + """Get model keyword arguments for the current batch.""" + # Compute the slice to return + q_len = self.total_query_length if self.slice_inputs else self.write_index_storage[0].size(-1) + b_size = self.total_batch_size if self.slice_inputs else self.cumulative_seqlens_q.size(-1) - 1 + + # Prepare the kwargs, the attributes that are either tensors or dict of tensors are initialized to empty dicts + kwargs = { + "input_ids": self.input_ids[:, :q_len], + "position_ids": self.position_ids[:, :q_len], + "cu_seq_lens_q": self.cumulative_seqlens_q[: b_size + 1], + "max_seqlen_q": self.max_seqlen_q, + "logits_indices": self.logits_indices[:q_len], + "cu_seq_lens_k": {}, + "max_seqlen_k": {}, + "attention_mask": {}, + "read_index": self.read_index, # slicing is done during building + "write_index": self.write_index, # slicing is done during building + "cache": self.cache, + "use_cache": False, + } + + # For the attributes that are dict of tensors, we replace the dict with a tensor if there is only one entry + 
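+        # e.g. a model with both "full_attention" and "sliding_attention" layers keeps dicts keyed by layer type,
+        # while a single-attention-type model gets plain tensors for cu_seq_lens_k / max_seqlen_k / attention_mask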
layer_types = list(self.cumulative_seqlens_k.keys()) + if len(layer_types) > 1: + for layer_type, seqlens_k in self.cumulative_seqlens_k.items(): + kwargs["cu_seq_lens_k"][layer_type] = seqlens_k[: b_size + 1] + kwargs["max_seqlen_k"][layer_type] = self.max_seqlen_k[layer_type] + if self.attention_mask is not None: + k_len = seqlens_k[b_size] if self.slice_inputs else self.attention_mask[layer_type].size(-1) + kwargs["attention_mask"][layer_type] = self.attention_mask[layer_type][..., :q_len, :k_len] + else: + layer_type = layer_types[0] + kwargs["cu_seq_lens_k"] = self.cumulative_seqlens_k[layer_type][: b_size + 1] + kwargs["max_seqlen_k"] = self.max_seqlen_k[layer_type] + if self.attention_mask is not None: + k_len = self.cumulative_seqlens_k[layer_type][b_size] + k_len = k_len if self.slice_inputs else self.attention_mask[layer_type].size(-1) + kwargs["attention_mask"] = self.attention_mask[layer_type][..., :q_len, :k_len] + + if self.attention_mask is None: + kwargs["attention_mask"] = None + return kwargs + + def __repr__(self): + return ( + f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, " + f"active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})" + + self.get_model_kwargs().__repr__() + ) + + @traced + def _get_new_requests(self): + """Pull new requests from the input queue and add to waiting list.""" + while not self.input_queue.empty(): + try: + state = self.input_queue.get_nowait() + if state is None: # Sentinel value + continue + self.scheduler.add_waiting_request(state) + + except queue.Empty: + break + except Exception as e: + logger.error(f"Error processing new request: {e}", exc_info=True) + state: RequestState = locals().get("state") + if state is not None: + self._handle_request_error(e, state) + + @traced + def _handle_request_error(self, error, state: RequestState): + """Handle general request processing error.""" + state.status = RequestStatus.FAILED + state.error = str(error) + + # Include any generated tokens if this is an active request + if isinstance(state.request_id, str): + state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id) + else: + state.static_outputs = [] + + self.metrics.record_request_completion(state.created_time, state.request_id) + self.output_queue.put(state.to_generation_output()) + + @traced + def prepare_next_batch(self) -> bool: + """Prepare tensors and metadata for the next model forward pass. 
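+        New requests are pulled from the input queue, the scheduler picks the requests for this batch, and the
+        preallocated static tensors are then filled in place.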
Returns True if there are requests to process, + False otherwise.""" + + # Get new requests from the queue, stop if there are no pending requests + self._get_new_requests() + self.scheduler.clear_cancelled_requests() + if not self.scheduler.has_pending_requests(): + return False + self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests)) + + # Schedule the next batch of requests, stop if there are no requests in the batch + self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens) + if not self.requests_in_batch: + return False + self.metrics.record_batch_metrics(self.requests_in_batch) + + # Reset the static tensors used for storage + self.reset_static_tensors() # TODO: with slice_inputs, this might be unnecessary + + # Prepare accumulators + self.total_query_length = 0 + self.total_key_length = 0 + self.total_batch_size = 0 + + input_ids = [] + position_ids = [] + cumulative_seqlens_q = [0] + logits_indices = [] + + if isinstance(self.cumulative_seqlens_k, dict): + cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k} + else: + cumulative_seqlens_k = [0] + + read_index = [[] for _ in range(self.cache.num_groups)] + write_index = [[] for _ in range(self.cache.num_groups)] + + # Go through all the requests in the batch + for state in self.requests_in_batch: + # First we retrieve the lengths related to the request + past_length = state.position_offset + query_length = len(state.prompt_ids) + seqlens_k = self.cache.get_seqlens_k(state.request_id, past_length, query_length) + + # Then we update the total lengths that are used for slicing + self.total_query_length += query_length + # total_key_length is used to slice the keys so we need to take the max of all the key lengths + self.total_key_length += max(seqlens_k.values()) + self.total_batch_size += 1 + # And the attribute tracking the position in the request object + state.position_offset += query_length + + # Then we accumulate for the object used in the kwargs + input_ids.extend(state.prompt_ids) + position_ids.extend(range(past_length, past_length + query_length)) + cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length) + self.max_seqlen_q = max(self.max_seqlen_q, query_length) + + if not state.remaining_prompt_ids: + logits_indices.append(cumulative_seqlens_q[-1] - 1) + + for layer_type, layer_type_seqlen_k in seqlens_k.items(): + cumulative_seqlens_k[layer_type].append(cumulative_seqlens_k[layer_type][-1] + layer_type_seqlen_k) + self.max_seqlen_k[layer_type] = max(self.max_seqlen_k[layer_type], layer_type_seqlen_k) + + self.cache.extend_read_indices(state.request_id, past_length, query_length, read_index) + self.cache.extend_write_indices(state.request_id, past_length, query_length, write_index) + + # When looping over request is done, we can build the actual tensors + self._build_tensors( + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ) + self.metrics.record_kv_cache_memory_metrics(self.cache) + + if logger.isEnabledFor(logging.DEBUG): + if isinstance(self.cumulative_seqlens_k, dict): + ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k) + else: + ck = cumulative_seqlens_k[-1] + logger.debug( + f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, " + f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. 
" + f"cum KV: {ck}, free blocks: {self.cache.get_num_free_blocks()}" + ) + return True + + @traced + def _build_tensors( + self, + input_ids: list[int], + position_ids: list[int], + read_index: list[list[int]], + write_index: list[list[int]], + cumulative_seqlens_q: list[int], + cumulative_seqlens_k: Union[list[int], dict[str, list[int]]], + logits_indices: list[int], + ) -> None: + """Builds the actual tensors for the current batch, by modifying the already allocated tensors in place.""" + to_tensor = partial(torch.tensor, **self.tensor_metadata) + + # Those kwargs always have the same type regardless of the model + self.input_ids[:, : len(input_ids)] = to_tensor(input_ids) + self.position_ids[:, : len(position_ids)] = to_tensor(position_ids) + self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q) + self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices) + + # Those kwargs are either dict of tensors or tensors, so we need to handle both cases + for layer_type, layer_type_seqlens_k in cumulative_seqlens_k.items(): + self.cumulative_seqlens_k[layer_type][: len(layer_type_seqlens_k)] = to_tensor(layer_type_seqlens_k) + if self.attention_mask is not None: + build_attention_mask( + attention_mask=self.attention_mask[layer_type], + cumulative_seqlens_q=cumulative_seqlens_q, + cumulative_seqlens_k=layer_type_seqlens_k, + sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1, + ) + + # The index only contain references to the storage tensors, so we update the storage and their references + self.read_index = [] + self.write_index = [] + for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index): + # Write in the actual tensors + self.read_index_storage[i][: len(group_read_indices)] = to_tensor(group_read_indices) + self.write_index_storage[i][: len(group_write_indices)] = to_tensor(group_write_indices) + # Slice to the right size + r = len(group_read_indices) if self.slice_inputs else self.read_index_storage[i].size(-1) + w = len(group_write_indices) if self.slice_inputs else self.write_index_storage[i].size(-1) + # Add to the index + self.read_index.append(self.read_index_storage[i][:r]) + self.write_index.append(self.write_index_storage[i][:w]) + + @traced + def _sync(self): + if self.output_ids is not None: + try: + out = self.output_ids.tolist()[0] # should be the only sync we do + except Exception: + out = [0, 1] + else: + out = [0, 0] + return out + + @traced + def _maybe_send_output(self, state: RequestState, token: int): + """Send output to the queue based on streaming mode and request state.""" + if self.streaming: + self.output_queue.put(state.to_generation_output()) + elif state.status == RequestStatus.FINISHED: + self.output_queue.put(state.to_generation_output()) + + @traced + def update_batch(self): + """Update request states based on generated tokens.""" + out_tokens = self._sync() + finished_request_ids = [] + for i, state in enumerate(self.requests_in_batch): + req_id = state.request_id + if len(state.remaining_prompt_ids) == 0: + self.metrics.record_ttft_metric(state.created_time, state.request_id) + state.status = RequestStatus.DECODING + token = out_tokens[self.logits_indices[i]] + state.prompt_ids = [token] + if state.update_with_token(token): + self.metrics.record_request_completion(state.created_time, state.request_id) + self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction)) + finished_request_ids.append(req_id) + 
self._maybe_send_output(state, token) + elif state.status == RequestStatus.PREFILLING_SPLIT: + state.status = RequestStatus.SPLIT_PENDING_REMAINDER + if self.cache.get_num_free_blocks() == 0: + raise ValueError("No more free blocks") + + @traced + def has_pending_requests(self) -> bool: + """Check if there are any active or waiting requests.""" + return self.scheduler.has_pending_requests() + + @traced + def handle_batch_error(self, error): + """Handle errors during batch processing.""" + failed_reqs = self.requests_in_batch + for req in failed_reqs: + self._handle_request_error(error, req) + self.scheduler.finish_request(req.request_id) + + @traced + def fail_all_requests(self, error): + """Fail all active requests with the given error. + + Args: + error: The error to report in the failure message + """ + + requests = list(self.scheduler.active_requests.values()) + for state in requests: + self._handle_request_error(error, state) + self.scheduler.finish_request(state.request_id) + + # Also fail any requests in the waiting queue + for req_id in list(self.scheduler.waiting_requests.keys()): + state = self.scheduler.waiting_requests.pop(req_id) + self._handle_request_error(error, state) + + # Clear the ordering queue + self.scheduler.waiting_requests_order.clear() + + +# Manager Class (User Interface) +@attach_tracer() +class ContinuousBatchingManager: + """Manager for handling continuous batching of generation requests. + + This class provides the user interface for submitting generation requests, + retrieving results, and managing the background generation thread. + """ + + def __init__( + self, + model, + generation_config: GenerationConfig, + manual_eviction: bool = False, + max_queue_size=0, + streaming: bool = True, + slice_inputs: bool = True, + ): + """Initialize the continuous batching manager. 
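+        The manager owns the input/output queues and the background generation thread: requests are submitted with
+        `add_request`, results are consumed with `get_result` (or by iterating over the manager), and the loop is
+        shut down with `stop`.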
+ + Args: + model: The language model for generation + generation_config: Configuration for generation parameters + max_queue_size: Maximum size of the request queue (0 = unlimited) + streaming: Whether to stream tokens as they are generated + """ + self.model = model.eval() + generation_config = model.generation_config if generation_config is None else generation_config + self.generation_config = generation_config + self.input_queue = queue.Queue(maxsize=max_queue_size) + self.output_queue = queue.Queue() + self.stop_event = threading.Event() + self.streaming = streaming + self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) + self._generation_thread = None + self._request_counter = 0 + self._request_lock = threading.Lock() + self.model.generation_config.top_p = None + self.do_sample = getattr(generation_config, "do_sample", True) + self.logit_processor = self.model._get_logits_processor(generation_config) + self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", False) # TODO: same as do_sample + self.profile = getattr(generation_config, "profile", False) + self.manual_eviction = manual_eviction + self.batch_processor: Optional[ContinuousBatchProcessor] = None + self.slice_inputs = slice_inputs + + if self.use_cuda_graph: + raise NotImplementedError("Cuda graphs are not supported yet") + + @traced + def start(self): + """Start the background generation thread.""" + if self._generation_thread is not None and self._generation_thread.is_alive(): + logger.warning("Manager thread is already running.") + return + + self._result_queue = queue.Queue() + self._generation_thread = threading.Thread(target=self._run_generation_loop) + self._generation_thread.start() + + def is_running(self): + """Check if the background generation thread is running.""" + return self._generation_thread is not None and self._generation_thread.is_alive() + + def stop(self, block: bool = False, timeout: Optional[float] = None): + """Signal the background thread to stop. + + Args: + block: Whether to wait for the thread to stop + timeout: Maximum time to wait for the thread to stop + """ + if self._generation_thread is None: + logger.warning("Manager not started.") + return + + if not self.stop_event.is_set(): + self.stop_event.set() + logger.info("Stopping continuous batching manager...") + + if block: + self.join(timeout) + + def join(self, timeout: Optional[float] = None): + """Wait for the background thread to finish. + + Args: + timeout: Maximum time to wait for the thread to stop + """ + if self._generation_thread is not None: + self._generation_thread.join(timeout=timeout) + if self._generation_thread.is_alive(): + logger.warning("Generation thread did not exit after join timeout.") + else: + logger.info("Continuous Batching Manager stopped.") + self._generation_thread = None + + def add_request( + self, input_ids: list[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None + ) -> str: + """Add a new generation request to the queue. 
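+        The request is wrapped in a `RequestState` and placed on the input queue for the background loop to
+        schedule; if the queue is full, the call blocks (with a 10s timeout) to apply backpressure.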
+ + Args: + input_ids: Input token IDs to use as prompt + request_id: Optional custom request ID (auto-generated if None) + **kwargs: Additional generation parameters + + Returns: + str: The request ID + """ + if request_id is None: + with self._request_lock: + request_id = f"req_{self._request_counter}" + self._request_counter += 1 + + max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens + + # NOTE: do we want to handle a case when the user wants token ids returned instead of decoded text? + state = RequestState( + request_id=request_id, + prompt_ids=list(input_ids), + full_prompt_ids=list(input_ids), + max_new_tokens=max_new_tokens, + eos_token_id=self.generation_config.eos_token_id, + ) + + # Use block=True with timeout to handle backpressure if queue is full + self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg? + logger.debug(f"Added request {request_id} to queue.") + return request_id + + def add_requests(self, inputs: list[list[int]], **kwargs): + for input_ids in inputs: + self.add_request(input_ids, **kwargs) + + def cancel_request(self, request_id: str): + """Cancel a request by its ID. + + Args: + request_id: The ID of the request to cancel + """ + if self.batch_processor is not None: + self.batch_processor.scheduler.set_request_cancellation(request_id) + + def get_result(self, request_id=None, timeout=None) -> Optional[GenerationOutput]: + """Retrieve one result from the output queue. + + Args: + timeout: Maximum time to wait for a result + + Returns: + Optional[GenerationOutput]: The result data or None if timeout + """ + if self._generation_thread is None and self.output_queue.empty(): + return None + try: + result = self.output_queue.get(block=True, timeout=timeout) + if request_id is not None and result.request_id != request_id: + self.output_queue.put(result) + return None + logger.debug(f"Retrieved result for request {result.request_id}") + return result + except queue.Empty: + return None + + def __iter__(self): + """Iterate over results as they become available.""" + while self._generation_thread is not None and self._generation_thread.is_alive(): + result = self.get_result(timeout=0.1) + if result is not None: + yield result + + def request_id_iter(self, request_id): + """Iterate over results matching a specific request id as they become available.""" + request_cancelled = False + while self._generation_thread is not None and self._generation_thread.is_alive() and not request_cancelled: + result = self.get_result(request_id=request_id, timeout=0.1) + if result is not None: + yield result + if self.batch_processor is not None: + request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id) + + @staticmethod + def supported_attention_implementations() -> set[str]: + return {"eager_paged", "sdpa_paged", "flash_attention_2"} + + @staticmethod + def default_attention_implementation() -> str: + return "sdpa_paged" + + @traced + def warmup(self, batch_processor): + stream = torch.cuda.Stream(device=self.model.device) + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + # Warmup the model with a dummy forward pass + self._generation_step(batch_processor) + torch.cuda.current_stream().wait_stream(stream) + + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, stream=stream): + self._generation_step(batch_processor) + + @traced + # @torch.compile + def _generation_step(self, batch_processor: ContinuousBatchProcessor): + 
"""Perform a single generation step. This is cuda graphed""" + batch_data = batch_processor.get_model_kwargs() + with torch.no_grad(): + logits = self._model_forward(batch_data) + if self.log_prob_generation: + batch_processor.output_probs.copy_(logits) # TODO + probs = self._process_logit(batch_data, logits) + self._sample(batch_processor, probs) + + @traced(span_name="model_forward") + def _model_forward(self, batch_data): + return self.model(**batch_data).logits + + @traced(span_name="logit_processing") + def _process_logit(self, batch_data, logits): + # Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner! + if hasattr(self.logit_processor, "set_continuous_batching_context"): + self.logit_processor.set_continuous_batching_context( + batch_data["logits_indices"], batch_data["cu_seq_lens_q"] + ) + + # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size] + # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size] + batch_size, seq_len, vocab_size = logits.shape + logits_2d = logits.view(batch_size * seq_len, vocab_size) + input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len) + + # Process with 2D tensors + processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d) + + # Reshape back to 3D + return processed_logits_2d.view(batch_size, seq_len, vocab_size) + + @traced(span_name="sampling") + def _sample(self, batch_processor: ContinuousBatchProcessor, probs): + if self.do_sample: # sample + probs = nn.functional.softmax(probs, dim=-1) + # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1] + next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len] + # Add batch dimension back to match argmax output + next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len] + else: + next_tokens = torch.argmax(probs, dim=-1) # Already [1, seq_len] + + tokens = next_tokens.size(1) # Get seq_len dimension + batch_processor.output_ids[:, :tokens].copy_(next_tokens) + + def _run_generation_loop(self): + """Main processing loop running in the background thread.""" + batch_processor = None + try: + ref_time = perf_counter() + paged_attention_cache = PagedAttentionCache( + self.model.config, + self.generation_config, + self.model.device, + self.model.dtype, + tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting + ) + logger.debug(f"PagedAttentionCache created in {perf_counter() - ref_time} seconds") + + scheduler = None + if hasattr(self.generation_config, "scheduler"): + scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler, None) + if scheduler is None: + logger.warning(f"Scheduler '{scheduler}' not found. 
Defaulting to FIFO.") + scheduler = FIFOScheduler + else: + # Default to fifo + scheduler = FIFOScheduler + + ref_time = perf_counter() + batch_processor = ContinuousBatchProcessor( + paged_attention_cache, + self.model.config, + self.generation_config, + self.input_queue, + self.output_queue, + self.stop_event, + self.model.device, + self.model.dtype, + scheduler(paged_attention_cache, self.manual_eviction), + self.streaming, + self.manual_eviction, + slice_inputs=self.slice_inputs, + ) + self.batch_processor = batch_processor + self.current_batch = 0 + logger.debug(f"batch_processor created in {perf_counter() - ref_time} seconds") + while (not self.stop_event.is_set()) or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor) + self.current_batch += 1 + + except Exception as e: + logger.error(f"Error in generation loop: {e}", exc_info=True) + self._handle_critical_error(e, batch_processor) + finally: + logger.info("Generation loop finished.") + + @traced(span_name="generation_loop") + def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor): + if torch.cuda.is_available(): + torch.cuda.synchronize() + if not batch_processor.prepare_next_batch(): + return + if logger.level <= logging.DEBUG: + device, total, reserved, allocated = get_device_and_memory_breakdown() + logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") + if torch.cuda.is_available() and self.use_cuda_graph: + if self.current_batch == 0: + self.warmup(batch_processor) + elif hasattr(self, "graph"): + try: + self._graph_replay() + except Exception as e: + logger.error(f"Model forward pass failed: {e}", exc_info=True) + batch_processor.handle_batch_error(e) + return + else: + self._generation_step(batch_processor) + else: + self._generation_step(batch_processor) + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.update_batch() + + @traced(span_name="graph_replay") + def _graph_replay(self): + self.graph.replay() + + @traced + def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]): + """Handle critical errors that terminate the generation loop.""" + # Signal stop + self.stop_event.set() + + # Fail pending requests in input queue + try: + while True: + req_data = self.input_queue.get_nowait() + if batch_processor is not None: + batch_processor._handle_request_error(error, req_data) + except queue.Empty: + pass + + # Fail active requests + if batch_processor is not None: + batch_processor.fail_all_requests(error) + + @traced + def evict_request_from_cache(self, request_id: str): + """Evict a request from the cache. It is assumed that the request is already finished.""" + if not self.manual_eviction: + raise RuntimeError("Manual eviction is not enabled for this manager.") + if self.batch_processor is not None: + self.batch_processor.scheduler.finish_request(request_id) + + +class ContinuousMixin: + """Mixin class for models to add continuous batching capabilities.""" + + def init_continuous_batching( + self, + generation_config: Optional[GenerationConfig] = None, + manual_eviction: bool = False, + max_queue_size: int = 0, + streaming: bool = False, + slice_inputs: bool = True, + ) -> ContinuousBatchingManager: + """Initialize a manager for continuous batching inference. 
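+        A minimal usage sketch (illustrative only; assumes the model defines `config`, `device` and `dtype` and has
+        a usable `generation_config`):
+
+        ```python
+        >>> manager = model.init_continuous_batching()
+        >>> manager.start()
+        >>> req_id = manager.add_request([1, 2, 3], max_new_tokens=8)
+        >>> result = manager.get_result(timeout=30)  # GenerationOutput, or None on timeout
+        >>> manager.stop(block=True)
+        ```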
+ + Args: + generation_config: Custom generation configuration + max_queue_size: Maximum size of the input request queue + streaming: Whether to stream tokens as they are generated + + Returns: + `ContinuousBatchingManager`: The manager instance to add requests and retrieve results. + """ + if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"): + raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.") + + gen_config = generation_config if generation_config is not None else self.generation_config + if gen_config is None: + raise ValueError("A GenerationConfig must be provided or set in the model.") + + if gen_config.eos_token_id is None: + logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).") + gen_config.eos_token_id = -1 + + # Create and return the manager + return ContinuousBatchingManager( + model=self, + generation_config=gen_config, + manual_eviction=manual_eviction, + max_queue_size=max_queue_size, + streaming=streaming, + slice_inputs=slice_inputs, + ) + + @traced + @torch.inference_mode() + def generate_batch( + self, + inputs: list[list[int]], + generation_config: Optional[GenerationConfig] = None, + progress_bar: bool = True, + slice_inputs: bool = True, + **kwargs, + ) -> list[list[int]]: + """Generate sequences for a batch of prompts using continuous batching. + + Args: + inputs: List of input token sequences (prompts) + generation_config: Optional generation configuration + **kwargs: Additional generation parameters + + Returns: + `list[list[int]]`: A list containing the generated sequences (including prompt tokens + if not handled otherwise) for each input prompt, in the same order. + Returns an empty list `[]` for requests that failed. + """ + if not inputs: + return [] + if logger.getEffectiveLevel() <= logging.DEBUG: + logger.warning("Progress bar is disabled when logger level is less than DEBUG") + progress_bar = False + + # Initialize manager with the batch inputs + manager = self.init_continuous_batching(generation_config=generation_config, slice_inputs=slice_inputs) + manager.start() + results = {} + num_requests = len(inputs) + try: + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm([logger]): + with tqdm( + total=num_requests, + disable=(not progress_bar), + desc=f"Solving {num_requests} requests", + unit="request", + ) as pbar: + manager.add_requests(inputs, **kwargs) + finished_count = 0 + while finished_count < num_requests: + result = manager.get_result(timeout=1) + if result: + req_id = result.request_id + if result.status == RequestStatus.FINISHED: + results[req_id] = result + finished_count += 1 + pbar.update(1) + else: + if not manager.is_running(): + logger.error("Generation thread terminated unexpectedly.") + break + + except Exception as e: + logger.error(f"Error during batch generation: {e}", exc_info=True) + finally: + manager.stop(block=True, timeout=5.0) + return results diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py new file mode 100644 index 0000000000000000000000000000000000000000..c7842735e7e449eeac7bbc766ed290a3791b5c1f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. 
team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +import torch + +from ...utils.logging import logging +from ...utils.metrics import traced + + +# We centralize the logger here to coordinate between logging and progress bar +logger = logging.getLogger("ContinuousBatchingLogger") +# logger.setLevel(logging.INFO) + + +def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: + if torch.cuda.is_available(): + device = torch.device("cuda") + torch.cuda.empty_cache() + torch.cuda.synchronize() + total_memory = torch.cuda.get_device_properties(device).total_memory + reserved_memory = torch.cuda.memory_reserved(device) + allocated_memory = torch.cuda.memory_allocated(device) + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + # MPS memory reporting (PyTorch 2.0+) + total_memory = torch.mps.driver_allocated_memory() + allocated_memory = total_memory - torch.mps.recommended_max_memory() + reserved_memory = 0 # MPS does not track reserved separately + else: + device = torch.device("cpu") + total_memory = None + reserved_memory = 0 + allocated_memory = 0 + return device, total_memory, reserved_memory, allocated_memory + + +class RequestStatus(Enum): + """Status of a generation request through its lifecycle.""" + + PENDING = "pending" + PREFILLING = "prefilling" + PREFILLING_SPLIT = "prefilling_split" + SPLIT_PENDING_REMAINDER = "split_pending_remainder" + DECODING = "decoding" + FINISHED = "finished" + FAILED = "failed" + + +@dataclass +class GenerationOutput: + """Tracks the output of a generation request. + + Attributes: + request_id (str): The ID of the generation request. + prompt_ids (list[int]): The IDs of the prompt tokens. + generated_tokens (list[int]): The generated tokens. + logprobs (list[float]): The log probabilities of the generated tokens. + error (Optional[str]): Any error message associated with the request. When None, the request was successful. + status (RequestStatus): The status of the request. + created_time (float): The time the request was created. + """ + + request_id: str + prompt_ids: list[int] = field(default_factory=list) + generated_tokens: list[int] = field(default_factory=list) + logprobs: list[float] = field(default_factory=list) + error: Optional[str] = None + status: RequestStatus = RequestStatus.PENDING + created_time: float = field(default_factory=time.time) + + +@dataclass +class RequestState: + """Tracks the state of a generation request through its lifecycle. + + Attributes: + request_id (str): The ID of the generation request. + full_prompt_ids (list[int] | None): The tokens IDs of the full prompt. + prompt_ids (list[int] | None): The tokens IDs currently being processed. + remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests). + static_outputs (list[int]): The generated tokens. 
+ allocated_blocks (int): The number of blocks allocated to the request. + position_offset (int): The current position in the sequence for position_ids. + status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT, + SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED + max_new_tokens (int): The maximum number of new tokens to generate. + eos_token_id (int): The ID of the end-of-sequence token. + created_time (float): The time the request was created. + error (Optional[str]): Any error message associated with the request. When None, has had no error yet. + """ + + # Required fields + request_id: str + full_prompt_ids: Optional[list[int]] = None # Full initial prompt + prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated) + remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process + static_outputs: list[int] = field(default_factory=list) # Generated tokens + allocated_blocks: int = 0 # Number of blocks allocated to the request + position_offset: int = 0 # Current position in the sequence for position_ids + _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property + max_new_tokens: int = 20 # Maximum number of new tokens to generate + eos_token_id: int = -1 # ID of the end-of-sequence token + created_time: float = field(default_factory=time.time) # Time the request was created + error: Optional[str] = None # Error message if the request failed + lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished) + + @property + def status(self) -> RequestStatus: + return self._status + + @status.setter + def status(self, value: RequestStatus): + if self._status == RequestStatus.PENDING: + self.lifespan = (time.time(), -1) + elif value == RequestStatus.FINISHED: + self.lifespan = (self.lifespan[0], time.time()) + self.log_end_of_request() + self._status = value + + def log_end_of_request(self): + prefill_len = len(self.full_prompt_ids) + decode_len = self.generated_len() + start_time = self.lifespan[0] - self.created_time + end_time = self.lifespan[1] - self.created_time + logger.info( + f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }" + ) + + def current_len(self) -> int: + """Get the current length of the sequence (prompt + generated tokens).""" + return self.position_offset + + def generated_len(self) -> int: + """Get the number of tokens generated so far.""" + return len(self.static_outputs) + + # TODO: this logic seems one token off, check it out + @traced + def update_with_token(self, token_id: int) -> bool: + """Update the request with a newly generated token and check for completion. 
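+        The token is appended to `static_outputs` unless generation ends purely because `max_new_tokens` was reached
+        without an EOS; an EOS token itself is still recorded before the request is marked FINISHED.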
+ + Args: + token_id: The token ID to add to the output sequence + + Returns: + bool: True if the request is now complete, False otherwise + """ + # Only update if we're in decoding state + if self.status != RequestStatus.DECODING: + return False + + is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 + is_max_len = self.generated_len() >= self.max_new_tokens + + # Only add the token if we're not finishing due to max length + # (EOS tokens should still be added to the output) + if not (is_max_len and not is_eos): + self.static_outputs.extend([token_id]) + + if is_eos or is_max_len: + self.status = RequestStatus.FINISHED + return True + return False + + def __repr__(self): + msg = [ + f"request_id={self.request_id}", + f"status={self._status}", + f"out_tokens={self.generated_len()}", + f"query_length={len(self.prompt_ids)}", + f"remaining_tokens={len(self.remaining_prompt_ids)}", + f"kv_length={self.position_offset}", + f"full_prompt_length={len(self.full_prompt_ids)}", + f"allocated_blocks={self.allocated_blocks}", + f"generated_tokens={self.static_outputs}", + ] + return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)" + + def to_generation_output(self): + """Convert the request state to a GenerationOutput object.""" + return GenerationOutput( + request_id=self.request_id, + prompt_ids=self.full_prompt_ids, + status=self.status, + generated_tokens=self.static_outputs, + logprobs=[], + error=self.error, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..67cddbf14190a65f367b87bf09cd3f7d323b6c80 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +from abc import ABC, abstractmethod +from collections import deque + +from ...utils.metrics import attach_tracer, traced +from .cache import PagedAttentionCache +from .requests import RequestState, RequestStatus + + +class Scheduler(ABC): + """ + Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of + requests from when they are added to the waiting queue to when they are scheduled for processing. Different + schedulers implement different strategies for prioritizing and batching requests. 
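+    Two concrete implementations are provided below: `FIFOScheduler` (registered as "fifo") and
+    `PrefillFirstScheduler` (registered as "prefill_first"), both exposed through `SCHEDULER_MAPPING`.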
+ """ + + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False): + self.active_requests: dict[str, RequestState] = {} + self.waiting_requests: dict[str, RequestState] = {} + self.waiting_requests_order: deque[str] = deque() + self.cache = cache + self.retain_cache_on_finish = retain_cache_on_finish + self._cancellation_lock = threading.Lock() + self._requests_to_cancel: set[str] = set() + + @traced + def add_waiting_request(self, state: RequestState): + """Adds a request to the waiting list.""" + if self.retain_cache_on_finish and state.request_id in self.active_requests: + old_state = self.active_requests.pop(state.request_id) + state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error? + state.allocated_blocks = old_state.allocated_blocks + state.position_offset = old_state.position_offset + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @abstractmethod + def schedule_batch(self, token_budget: int) -> list[RequestState]: + """Schedules requests for the next batch based on available token budget. This method selects which requests + should be processed in the current batch, considering the token budget and the scheduler's prioritization rules. + The token_budget is the maximum number of tokens that can be processed in this batch.""" + pass + + @traced + def has_pending_requests(self) -> bool: + """Checks if there are requests ready to be processed.""" + return len(self.active_requests) or len(self.waiting_requests) + + @traced + def finish_request(self, request_id: str, evict_from_cache: bool = True): + """Completes processing of a request and optionally frees its allocated cache blocks. This method is called + when a request has finished generation or encountered an error. + """ + if evict_from_cache: + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + @traced + def get_active_request_static_outputs(self, request_id: str) -> list[int]: + """Gets generated tokens for an active request.""" + if request_id in self.active_requests: + return self.active_requests[request_id].static_outputs + return [] + + @traced + def set_request_cancellation(self, request_id: str): + """Marks a request for cancellation.""" + with self._cancellation_lock: + self._requests_to_cancel.add(request_id) + + @traced + def clear_cancelled_requests(self): + """Remove all cancelled requests from active and waiting queues.""" + with self._cancellation_lock: + for request_id in self._requests_to_cancel: + if request_id in self.active_requests: + del self.active_requests[request_id] + if request_id in self.waiting_requests: + del self.waiting_requests[request_id] + if request_id in self.waiting_requests_order: + self.waiting_requests_order.remove(request_id) + self.cache.free_blocks(request_id) + self._requests_to_cancel = set() + + @traced + def request_is_cancelled(self, request_id: str) -> bool: + """Checks if a request has been cancelled or removed.""" + return request_id in self._requests_to_cancel or ( + request_id not in self.active_requests and request_id not in self.waiting_requests + ) + + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool: + """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to + accommodate the next tokens. 
It calculates how many blocks are needed based on the request's current + cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator + objects. Returns a boolean indicating if the allocation was successful or not. + """ + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = state.allocated_blocks * self.cache.block_size - current_len + if occupancy < len_next_tokens or state.allocated_blocks == 0: + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if allocated is None: + return False + state.allocated_blocks += allocated + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str] + ): + """Prepares a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + +@attach_tracer() +class FIFOScheduler(Scheduler): + """This scheduler processes requests in the order they arrive, meaning decoding requests has priority over + prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default, + when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests.""" + + def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2): + """Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop + scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks, + or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests. 
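+        For example (illustrative numbers only): with a cache of 1000 blocks and the default safety_margin = 0.2,
+        once fewer than 200 blocks are free, new prefill requests stop being scheduled and the batch is limited to
+        already-decoding requests (a single prefill is still allowed if nothing else could be scheduled).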
+ """ + super().__init__(cache, retain_cache_on_finish) + self.safety_margin = safety_margin + + @traced + def schedule_batch(self, token_budget: int) -> list[RequestState]: + priority_states: list[RequestState] = [] + second_priority_states: list[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.DECODING: + priority_states.append(state) + if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]: + second_priority_states.append(state) + + # Add waiting requests to second priority + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + request_ids_to_remove_from_waiting = set() + safety_margins = self.safety_margin * self.cache.num_blocks + + for state in candidates: + # If we are out the safety margin, we only accept decoding requests or the first prefill request + num_free_blocks = self.cache.get_num_free_blocks() + outside_safety_margin = num_free_blocks < safety_margins + if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING: + break + + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + +# FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it +@attach_tracer() +class PrefillFirstScheduler(Scheduler): + """Scheduler that prioritizes split prefill requests over decoding requests. 
This scheduler ensures that split + prefill requests (which are continuations of partially processed prompts) are completed before processing new + decoding requests.""" + + @traced + def schedule_batch(self, token_budget: int) -> list[RequestState]: + priority_states: list[RequestState] = [] + second_priority_states: list[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + # XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account + if state.status in [RequestStatus.PREFILLING_SPLIT, RequestStatus.SPLIT_PENDING_REMAINDER]: + priority_states.append(state) + elif state.status == RequestStatus.DECODING: + second_priority_states.append(state) + + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + +SCHEDULER_MAPPING = { + "fifo": FIFOScheduler, + "prefill_first": PrefillFirstScheduler, +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57b5747909e091ede05ff07c98254224fbebed97 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_albert import * + from .modeling_albert import * + from .modeling_flax_albert import * + from .modeling_tf_albert import * + from .tokenization_albert import * + from .tokenization_albert_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..b60c19d504f05f50abb0988341526f53af8ad4db --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ALBERT model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig + + +class AlbertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used + to instantiate an ALBERT model according to the specified arguments, defining the model architecture. Instantiating + a configuration with the defaults will yield a similar configuration to that of the ALBERT + [albert/albert-xxlarge-v2](https://huggingface.co/albert/albert-xxlarge-v2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30000): + Vocabulary size of the ALBERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`]. + embedding_size (`int`, *optional*, defaults to 128): + Dimensionality of vocabulary embeddings. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_hidden_groups (`int`, *optional*, defaults to 1): + Number of groups for the hidden layers, parameters in the same group are shared. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. 
+ intermediate_size (`int`, *optional*, defaults to 16384): + The dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + inner_group_num (`int`, *optional*, defaults to 1): + The number of inner repetition of attention and ffn. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + classifier_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for attached classifiers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 3): + End of stream token id. + + Examples: + + ```python + >>> from transformers import AlbertConfig, AlbertModel + + >>> # Initializing an ALBERT-xxlarge style configuration + >>> albert_xxlarge_configuration = AlbertConfig() + + >>> # Initializing an ALBERT-base style configuration + >>> albert_base_configuration = AlbertConfig( + ... hidden_size=768, + ... num_attention_heads=12, + ... intermediate_size=3072, + ... 
) + + >>> # Initializing a model (with random weights) from the ALBERT-base style configuration + >>> model = AlbertModel(albert_xxlarge_configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "albert" + + def __init__( + self, + vocab_size=30000, + embedding_size=128, + hidden_size=4096, + num_hidden_layers=12, + num_hidden_groups=1, + num_attention_heads=64, + intermediate_size=16384, + inner_group_num=1, + hidden_act="gelu_new", + hidden_dropout_prob=0, + attention_probs_dropout_prob=0, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + classifier_dropout_prob=0.1, + position_embedding_type="absolute", + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_hidden_groups = num_hidden_groups + self.num_attention_heads = num_attention_heads + self.inner_group_num = inner_group_num + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.classifier_dropout_prob = classifier_dropout_prob + self.position_embedding_type = position_embedding_type + + +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert +class AlbertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) + + +__all__ = ["AlbertConfig", "AlbertOnnxConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc129366baea19b78ab5e7335fa21c5a371326b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py @@ -0,0 +1,1349 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
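Editor's note: the configuration above pairs a small `embedding_size` (128 by default) with a much larger `hidden_size` (4096): ALBERT embeds the vocabulary at `embedding_size` and projects up to `hidden_size` inside the encoder (see `embedding_hidden_mapping_in` in the modeling code below). A rough back-of-the-envelope sketch of the resulting parameter saving, using only the default values documented above; this is illustrative arithmetic, not part of the library:

```python
# Parameter count of ALBERT's factorized embedding (vocab -> embedding_size,
# then embedding_size -> hidden_size) versus embedding directly at hidden_size.
# Values are the AlbertConfig defaults documented above.
vocab_size = 30000
embedding_size = 128
hidden_size = 4096

factorized = vocab_size * embedding_size + embedding_size * hidden_size
unfactorized = vocab_size * hidden_size

print(f"factorized:   {factorized:,} parameters")    # ~4.4M
print(f"unfactorized: {unfactorized:,} parameters")  # ~122.9M
```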
+"""PyTorch ALBERT model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from ...utils import ModelOutput, auto_docstring, logging +from .configuration_albert import AlbertConfig + + +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_albert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + print(name) + + for name, array in zip(names, arrays): + original_name = name + + # If saved from the TF HUB module + name = name.replace("module/", "") + + # Renaming and simplifying + name = name.replace("ffn_1", "ffn") + name = name.replace("bert/", "albert/") + name = name.replace("attention_1", "attention") + name = name.replace("transform/", "") + name = name.replace("LayerNorm_1", "full_layer_layer_norm") + name = name.replace("LayerNorm", "attention/LayerNorm") + name = name.replace("transformer/", "") + + # The feed forward layer had an 'intermediate' step which has been abstracted away + name = name.replace("intermediate/dense/", "") + name = name.replace("ffn/intermediate/output/dense/", "ffn_output/") + + # ALBERT attention was split between self and output which have been abstracted away + name = name.replace("/output/", "/") + name = name.replace("/self/", "/") + + # The pooler is a linear layer + name = name.replace("pooler/dense", "pooler") + + # The classifier was simplified to predictions from cls/predictions + name = name.replace("cls/predictions", "predictions") + name = name.replace("predictions/attention", "predictions") + + # Naming was changed to be more explicit + name = name.replace("embeddings/attention", "embeddings") + name = name.replace("inner_group_", "albert_layers/") + name = name.replace("group_", "albert_layer_groups/") + + # Classifier + if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name): + name = "classifier/" + name + + # No ALBERT model currently handles the next sentence prediction task + if "seq_relationship" in name: + name = name.replace("seq_relationship/output_", "sop_classifier/classifier/") + name = name.replace("weights", "weight") + + name = name.split("/") + + # Ignore the gradients applied by the LAMB/ADAM 
optimizers. + if ( + "adam_m" in name + or "adam_v" in name + or "AdamWeightDecayOptimizer" in name + or "AdamWeightDecayOptimizer_1" in name + or "global_step" in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + print(f"Initialize PyTorch weight {name} from {original_name}") + pointer.data = torch.from_numpy(array) + + return model + + +class AlbertEmbeddings(nn.Module): + """ + Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config: AlbertConfig): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if 
token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class AlbertAttention(nn.Module): + def __init__(self, config: AlbertConfig): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads}" + ) + + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.output_dropout = nn.Dropout(config.hidden_dropout_prob) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pruned_heads = set() + + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + def prune_heads(self, heads: list[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.query = prune_linear_layer(self.query, index) + self.key = prune_linear_layer(self.key, index) + self.value = prune_linear_layer(self.value, index) + self.dense = prune_linear_layer(self.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.num_attention_heads = self.num_attention_heads - len(heads) + self.all_head_size = self.attention_head_size * self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: bool = False, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + key_layer = self.key(hidden_states) + value_layer = self.value(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, 
self.attention_head_size).transpose( + 1, 2 + ) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.transpose(2, 1).flatten(2) + + projected_context_layer = self.dense(context_layer) + projected_context_layer_dropout = self.output_dropout(projected_context_layer) + layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout) + return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,) + + +class AlbertSdpaAttention(AlbertAttention): + def __init__(self, config): + super().__init__(config) + self.dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: bool = False, + ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: + if self.position_embedding_type != "absolute" or output_attentions: + logger.warning( + "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` . 
Falling back to " + "the eager attention implementation, but specifying the eager implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward(hidden_states, attention_mask, output_attentions=output_attentions) + + batch_size, seq_len, _ = hidden_states.size() + query_layer = ( + self.query(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + key_layer = ( + self.key(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(hidden_states) + .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + attention_output = torch.nn.functional.scaled_dot_product_attention( + query=query_layer, + key=key_layer, + value=value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=False, + ) + + attention_output = attention_output.transpose(1, 2) + attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size) + + projected_context_layer = self.dense(attention_output) + projected_context_layer_dropout = self.output_dropout(projected_context_layer) + layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout) + return (layernormed_context_layer,) + + +ALBERT_ATTENTION_CLASSES = { + "eager": AlbertAttention, + "sdpa": AlbertSdpaAttention, +} + + +class AlbertLayer(nn.Module): + def __init__(self, config: AlbertConfig): + super().__init__() + + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config) + self.ffn = nn.Linear(config.hidden_size, config.intermediate_size) + self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions) + + ffn_output = apply_chunking_to_forward( + self.ff_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[0], + ) + hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0]) + + return (hidden_states,) + attention_output[1:] # add attentions if we output them + + def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor: + ffn_output = self.ffn(attention_output) + ffn_output = self.activation(ffn_output) + ffn_output = self.ffn_output(ffn_output) + return ffn_output + + +class AlbertLayerGroup(nn.Module): + def __init__(self, config: AlbertConfig): + super().__init__() + + self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)]) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = 
False, + ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]: + layer_hidden_states = () + layer_attentions = () + + for layer_index, albert_layer in enumerate(self.albert_layers): + layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions) + hidden_states = layer_output[0] + + if output_attentions: + layer_attentions = layer_attentions + (layer_output[1],) + + if output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (layer_hidden_states,) + if output_attentions: + outputs = outputs + (layer_attentions,) + return outputs # last-layer hidden state, (layer hidden states), (layer attentions) + + +class AlbertTransformer(nn.Module): + def __init__(self, config: AlbertConfig): + super().__init__() + + self.config = config + self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size) + self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[BaseModelOutput, tuple]: + hidden_states = self.embedding_hidden_mapping_in(hidden_states) + + all_hidden_states = (hidden_states,) if output_hidden_states else None + all_attentions = () if output_attentions else None + + head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask + + for i in range(self.config.num_hidden_layers): + # Number of layers in a hidden group + layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) + + # Index of the hidden group + group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) + + layer_group_output = self.albert_layer_groups[group_idx]( + hidden_states, + attention_mask, + head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group], + output_attentions, + output_hidden_states, + ) + hidden_states = layer_group_output[0] + + if output_attentions: + all_attentions = all_attentions + layer_group_output[-1] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +@auto_docstring +class AlbertPreTrainedModel(PreTrainedModel): + config: AlbertConfig + load_tf_weights = load_tf_weights_in_albert + base_model_prefix = "albert" + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, 
AlbertMLMHead): + module.bias.data.zero_() + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`AlbertForPreTraining`]. + """ +) +class AlbertForPreTrainingOutput(ModelOutput): + r""" + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + sop_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring +class AlbertModel(AlbertPreTrainedModel): + config: AlbertConfig + base_model_prefix = "albert" + + def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + + self.config = config + self.embeddings = AlbertEmbeddings(config) + self.encoder = AlbertTransformer(config) + if add_pooling_layer: + self.pooler = nn.Linear(config.hidden_size, config.hidden_size) + self.pooler_activation = nn.Tanh() + else: + self.pooler = None + self.pooler_activation = None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: nn.Embedding) -> None: + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has + a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT + model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different layers. + + These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer, + while [2,3] correspond to the two inner groups of the second hidden layer. + + Any layer with in index other than [0,1,2,3] will result in an error. 
See base class PreTrainedModel for more + information about head pruning + """ + for layer, heads in heads_to_prune.items(): + group_idx = int(layer / self.config.inner_group_num) + inner_group_idx = int(layer - group_idx * self.config.inner_group_num) + self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutputWithPooling, tuple]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + use_sdpa_attention_mask = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + if use_sdpa_attention_mask: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min + + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return 
BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a + `sentence order prediction (classification)` head. + """ +) +class AlbertForPreTraining(AlbertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] + + def __init__(self, config: AlbertConfig): + super().__init__(config) + + self.albert = AlbertModel(config) + self.predictions = AlbertMLMHead(config) + self.sop_classifier = AlbertSOPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self) -> nn.Linear: + return self.predictions.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.predictions.decoder = new_embeddings + + def get_input_embeddings(self) -> nn.Embedding: + return self.albert.embeddings.word_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + sentence_order_label: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[AlbertForPreTrainingOutput, tuple]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then + sequence B), `1` indicates switched order (sequence B, then sequence A). 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AlbertForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") + >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2") + + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) + >>> # Batch size 1 + >>> outputs = model(input_ids) + + >>> prediction_logits = outputs.prediction_logits + >>> sop_logits = outputs.sop_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.albert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + + prediction_scores = self.predictions(sequence_output) + sop_scores = self.sop_classifier(pooled_output) + + total_loss = None + if labels is not None and sentence_order_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1)) + total_loss = masked_lm_loss + sentence_order_loss + + if not return_dict: + output = (prediction_scores, sop_scores) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return AlbertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + sop_logits=sop_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AlbertMLMHead(nn.Module): + def __init__(self, config: AlbertConfig): + super().__init__() + + self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.dense = nn.Linear(config.hidden_size, config.embedding_size) + self.decoder = nn.Linear(config.embedding_size, config.vocab_size) + self.activation = ACT2FN[config.hidden_act] + self.decoder.bias = self.bias + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.decoder(hidden_states) + + prediction_scores = hidden_states + + return prediction_scores + + def _tie_weights(self) -> None: + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias + + +class AlbertSOPHead(nn.Module): + def __init__(self, config: AlbertConfig): + super().__init__() + + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: + dropout_pooled_output = self.dropout(pooled_output) + logits = self.classifier(dropout_pooled_output) + return logits + + +@auto_docstring +class AlbertForMaskedLM(AlbertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + self.albert = 
AlbertModel(config, add_pooling_layer=False) + self.predictions = AlbertMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self) -> nn.Linear: + return self.predictions.decoder + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + self.predictions.decoder = new_embeddings + self.predictions.bias = new_embeddings.bias + + def get_input_embeddings(self) -> nn.Embedding: + return self.albert.embeddings.word_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MaskedLMOutput, tuple]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, AlbertForMaskedLM + + >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") + >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2") + + >>> # add mask_token + >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt") + >>> with torch.no_grad(): + ... 
logits = model(**inputs).logits + + >>> # retrieve index of [MASK] + >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] + >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) + >>> tokenizer.decode(predicted_token_id) + 'france' + ``` + + ```python + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] + >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + >>> outputs = model(**inputs, labels=labels) + >>> round(outputs.loss.item(), 2) + 0.81 + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_outputs = outputs[0] + + prediction_scores = self.predictions(sequence_outputs) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """ +) +class AlbertForSequenceClassification(AlbertPreTrainedModel): + def __init__(self, config: AlbertConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.albert = AlbertModel(config) + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[SequenceClassifierOutput, tuple]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class AlbertForTokenClassification(AlbertPreTrainedModel): + def __init__(self, config: AlbertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.albert = AlbertModel(config, add_pooling_layer=False) + classifier_dropout_prob = ( + config.classifier_dropout_prob + if config.classifier_dropout_prob is not None + else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[TokenClassifierOutput, tuple]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.albert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class AlbertForQuestionAnswering(AlbertPreTrainedModel): + def __init__(self, config: AlbertConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.albert = AlbertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[AlbertForPreTrainingOutput, tuple]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits: torch.Tensor = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + 
loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class AlbertForMultipleChoice(AlbertPreTrainedModel): + def __init__(self, config: AlbertConfig): + super().__init__(config) + + self.albert = AlbertModel(config) + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[AlbertForPreTrainingOutput, tuple]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. 
(see + *input_ids* above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + outputs = self.albert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits: torch.Tensor = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "load_tf_weights_in_albert", + "AlbertPreTrainedModel", + "AlbertModel", + "AlbertForPreTraining", + "AlbertForMaskedLM", + "AlbertForSequenceClassification", + "AlbertForTokenClassification", + "AlbertForQuestionAnswering", + "AlbertForMultipleChoice", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f19cb27716fb3f8846ef88e870e3eb1188a4bf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py @@ -0,0 +1,1132 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
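Editor's note: the PyTorch classes defined above can be smoke-tested without downloading any pretrained weights by instantiating them from a deliberately tiny `AlbertConfig`. A minimal sketch, assuming `torch` and this `transformers` package are importable; the sizes below are arbitrary illustrative values, not a real ALBERT variant:

```python
import torch
from transformers import AlbertConfig, AlbertModel, AlbertForSequenceClassification

# Tiny, randomly initialized configuration. hidden_size must be divisible by
# num_attention_heads, and num_hidden_layers should be divisible by num_hidden_groups.
config = AlbertConfig(
    vocab_size=100,
    embedding_size=16,
    hidden_size=32,
    num_hidden_layers=4,
    num_hidden_groups=2,
    num_attention_heads=4,
    intermediate_size=64,
    max_position_embeddings=64,
    num_labels=3,
)

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # batch of 2, sequence length 10
attention_mask = torch.ones_like(input_ids)

model = AlbertModel(config)
out = model(input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # torch.Size([2, 10, 32])
print(out.pooler_output.shape)      # torch.Size([2, 32])

classifier = AlbertForSequenceClassification(config)
logits = classifier(input_ids, attention_mask=attention_mask).logits
print(logits.shape)                 # torch.Size([2, 3])
```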
+ +from typing import Callable, Optional + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPooling, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_albert import AlbertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "albert/albert-base-v2" +_CONFIG_FOR_DOC = "AlbertConfig" + + +@flax.struct.dataclass +class FlaxAlbertForPreTrainingOutput(ModelOutput): + """ + Output type of [`FlaxAlbertForPreTraining`]. + + Args: + prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + prediction_logits: jnp.ndarray = None + sop_logits: jnp.ndarray = None + hidden_states: Optional[tuple[jnp.ndarray]] = None + attentions: Optional[tuple[jnp.ndarray]] = None + + +ALBERT_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. 
+ + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`AlbertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +ALBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ +""" + + +class FlaxAlbertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxAlbertSelfAttention(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + # Convert the boolean attention mask to an attention bias. 
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + projected_attn_output = self.dense(attn_output) + projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic) + layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states) + outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,) + return outputs + + +class FlaxAlbertLayer(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype) + self.ffn = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + self.ffn_output = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + ): + attention_outputs = self.attention( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attention_output = attention_outputs[0] + ffn_output = self.ffn(attention_output) + ffn_output = self.activation(ffn_output) + ffn_output = self.ffn_output(ffn_output) + ffn_output = self.dropout(ffn_output, deterministic=deterministic) + hidden_states = self.full_layer_layer_norm(ffn_output + attention_output) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs + + +class FlaxAlbertLayerCollection(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num) + ] + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + layer_hidden_states = () + layer_attentions = () + + for layer_index, albert_layer in enumerate(self.layers): + layer_output = albert_layer( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + ) + hidden_states = layer_output[0] + + if output_attentions: + layer_attentions = layer_attentions + 
(layer_output[1],) + + if output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (layer_hidden_states,) + if output_attentions: + outputs = outputs + (layer_attentions,) + return outputs # last-layer hidden state, (layer hidden states), (layer attentions) + + +class FlaxAlbertLayerCollections(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + layer_index: Optional[str] = None + + def setup(self): + self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + outputs = self.albert_layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + return outputs + + +class FlaxAlbertLayerGroups(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_groups) + ] + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for i in range(self.config.num_hidden_layers): + # Index of the hidden group + group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) + layer_group_output = self.layers[group_idx]( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = layer_group_output[0] + + if output_attentions: + all_attentions = all_attentions + layer_group_output[-1] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxAlbertEncoder(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedding_hidden_mapping_in = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embedding_hidden_mapping_in(hidden_states) + return self.albert_layer_groups( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class FlaxAlbertOnlyMLMHead(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.dense = 
nn.Dense(self.config.embedding_size, dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + hidden_states += self.bias + return hidden_states + + +class FlaxAlbertSOPHead(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dropout = nn.Dropout(self.config.classifier_dropout_prob) + self.classifier = nn.Dense(2, dtype=self.dtype) + + def __call__(self, pooled_output, deterministic=True): + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + return logits + + +class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = AlbertConfig + base_model_prefix = "albert" + module_class: nn.Module = None + + def __init__( + self, + config: AlbertConfig, + input_shape: tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + attention_mask = jnp.ones_like(input_ids) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + 
# init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxAlbertModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype) + if self.add_pooling_layer: + self.pooler = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + name="pooler", + ) + self.pooler_activation = nn.tanh + else: + self.pooler = None + self.pooler_activation = None + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[np.ndarray] = None, + position_ids: Optional[np.ndarray] = None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if self.add_pooling_layer: + pooled = self.pooler(hidden_states[:, 0]) + pooled = self.pooler_activation(pooled) + else: + pooled = None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.", + ALBERT_START_DOCSTRING, +) +class FlaxAlbertModel(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertModule + + +append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) + + +class FlaxAlbertForPreTrainingModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) + self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) + self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype) + + def __call__( + 
self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.tie_word_embeddings: + shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + hidden_states = outputs[0] + pooled_output = outputs[1] + + prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) + sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic) + + if not return_dict: + return (prediction_scores, sop_scores) + outputs[2:] + + return FlaxAlbertForPreTrainingOutput( + prediction_logits=prediction_scores, + sop_logits=sop_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a + `sentence order prediction (classification)` head. + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForPreTrainingModule + + +FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") + >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.sop_logits + ``` +""" + +overwrite_call_docstring( + FlaxAlbertForPreTraining, + ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + +class FlaxAlbertForMaskedLMModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.predictions(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + 
attentions=outputs.attentions, + ) + + +@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING) +class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForMaskedLMModule + + +append_call_sample_docstring( + FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11" +) + + +class FlaxAlbertForSequenceClassificationModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) + classifier_dropout = ( + self.config.classifier_dropout_prob + if self.config.classifier_dropout_prob is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxAlbertForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxAlbertForMultipleChoiceModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxAlbertForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxAlbertForTokenClassificationModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + classifier_dropout = ( + self.config.classifier_dropout_prob + if self.config.classifier_dropout_prob is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForTokenClassificationModule + + +append_call_sample_docstring( + FlaxAlbertForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxAlbertForQuestionAnsweringModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxAlbertForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + +__all__ = [ + "FlaxAlbertPreTrainedModel", + "FlaxAlbertModel", + "FlaxAlbertForPreTraining", + "FlaxAlbertForMaskedLM", + "FlaxAlbertForSequenceClassification", + "FlaxAlbertForMultipleChoice", + "FlaxAlbertForTokenClassification", + "FlaxAlbertForQuestionAnswering", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..101ab63dc0545992fe68a53205f4ad81c607d9ca --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py @@ -0,0 +1,1572 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 ALBERT model.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_albert import AlbertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "albert/albert-base-v2" +_CONFIG_FOR_DOC = "AlbertConfig" + + +class TFAlbertPreTrainingLoss: + """ + Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP + + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + """ + + def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + if self.config.tf_legacy_loss: + # make sure only labels that are not equal to -100 + # are taken into account as loss + masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100) + masked_lm_reduced_logits = tf.boolean_mask( + tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])), + mask=masked_lm_active_loss, + ) + masked_lm_labels = tf.boolean_mask( + tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss + ) + sentence_order_active_loss = tf.not_equal( + tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100 + ) + sentence_order_reduced_logits = tf.boolean_mask( + tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss + ) + sentence_order_label = tf.boolean_mask( + tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss + ) + masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits) + sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits) + masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0])) + masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0) + + return masked_lm_loss + sentence_order_loss + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0]) + # make sure only labels that are not equal to -100 + # are taken into account for the loss computation + lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype) + masked_lm_losses = unmasked_lm_losses * lm_loss_mask + reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask) + + sop_logits = tf.reshape(logits[1], (-1, 2)) + # Clip 
negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits) + sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype) + + masked_sop_loss = unmasked_sop_loss * sop_loss_mask + reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask) + + return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,)) + + +class TFAlbertEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config: AlbertConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = config.embedding_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.embedding_size]) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call + def call( + self, + input_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + past_key_values_length=0, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. 
+ """ + if input_ids is None and inputs_embeds is None: + raise ValueError("Need to provide either `input_ids` or `input_embeds`.") + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +class TFAlbertAttention(keras.layers.Layer): + """Contains the complete attention sublayer, including both dropouts and layer norm.""" + + def __init__(self, config: AlbertConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + self.output_attentions = config.output_attentions + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993 + self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + batch_size = shape_list(input_tensor)[0] + mixed_query_layer = 
self.query(inputs=input_tensor) + mixed_key_layer = self.key(inputs=input_tensor) + mixed_value_layer = self.value(inputs=input_tensor) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFAlbertModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + context_layer = tf.matmul(attention_probs, value_layer) + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size)) + self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + hidden_states = self_outputs[0] + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.output_dropout(inputs=hidden_states, training=training) + attention_output = self.LayerNorm(inputs=hidden_states + input_tensor) + + # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFAlbertLayer(keras.layers.Layer): + def __init__(self, config: AlbertConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFAlbertAttention(config, name="attention") + self.ffn = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn" + ) + + if isinstance(config.hidden_act, str): + self.activation = get_tf_activation(config.hidden_act) + else: + self.activation = config.hidden_act + + self.ffn_output = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output" + ) + self.full_layer_layer_norm = 
keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="full_layer_layer_norm" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + training=training, + ) + ffn_output = self.ffn(inputs=attention_outputs[0]) + ffn_output = self.activation(ffn_output) + ffn_output = self.ffn_output(inputs=ffn_output) + ffn_output = self.dropout(inputs=ffn_output, training=training) + hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0]) + + # add attentions if we output them + outputs = (hidden_states,) + attention_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "ffn", None) is not None: + with tf.name_scope(self.ffn.name): + self.ffn.build([None, None, self.config.hidden_size]) + if getattr(self, "ffn_output", None) is not None: + with tf.name_scope(self.ffn_output.name): + self.ffn_output.build([None, None, self.config.intermediate_size]) + if getattr(self, "full_layer_layer_norm", None) is not None: + with tf.name_scope(self.full_layer_layer_norm.name): + self.full_layer_layer_norm.build([None, None, self.config.hidden_size]) + + +class TFAlbertLayerGroup(keras.layers.Layer): + def __init__(self, config: AlbertConfig, **kwargs): + super().__init__(**kwargs) + + self.albert_layers = [ + TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num) + ] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + output_hidden_states: bool, + training: bool = False, + ) -> TFBaseModelOutput | tuple[tf.Tensor]: + layer_hidden_states = () if output_hidden_states else None + layer_attentions = () if output_attentions else None + + for layer_index, albert_layer in enumerate(self.albert_layers): + if output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + layer_output = albert_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[layer_index], + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_output[0] + + if output_attentions: + layer_attentions = layer_attentions + (layer_output[1],) + + # Add last layer + if output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert_layers", None) is not None: + for layer in self.albert_layers: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFAlbertTransformer(keras.layers.Layer): + def __init__(self, config: AlbertConfig, **kwargs): + super().__init__(**kwargs) + + self.num_hidden_layers = config.num_hidden_layers + self.num_hidden_groups = config.num_hidden_groups + # Number of layers in a hidden group + self.layers_per_group = int(config.num_hidden_layers / 
config.num_hidden_groups) + self.embedding_hidden_mapping_in = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="embedding_hidden_mapping_in", + ) + self.albert_layer_groups = [ + TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups) + ] + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> TFBaseModelOutput | tuple[tf.Tensor]: + hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states) + all_attentions = () if output_attentions else None + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for i in range(self.num_hidden_layers): + # Index of the hidden group + group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups)) + layer_group_output = self.albert_layer_groups[group_idx]( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + training=training, + ) + hidden_states = layer_group_output[0] + + if output_attentions: + all_attentions = all_attentions + layer_group_output[-1] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embedding_hidden_mapping_in", None) is not None: + with tf.name_scope(self.embedding_hidden_mapping_in.name): + self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size]) + if getattr(self, "albert_layer_groups", None) is not None: + for layer in self.albert_layer_groups: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFAlbertPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = AlbertConfig + base_model_prefix = "albert" + + +class TFAlbertMLMHead(keras.layers.Layer): + def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = config.embedding_size + self.dense = keras.layers.Dense( + config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + if isinstance(config.hidden_act, str): + self.activation = get_tf_activation(config.hidden_act) + else: + self.activation = config.hidden_act + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + self.decoder_bias = self.add_weight( + shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias" + ) + + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.embedding_size]) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self.decoder + + def set_output_embeddings(self, value: tf.Variable): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self) -> dict[str, tf.Variable]: + return {"bias": self.bias, "decoder_bias": self.decoder_bias} + + def set_bias(self, value: tf.Variable): + self.bias = value["bias"] + self.decoder_bias = value["decoder_bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(inputs=hidden_states) + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias) + + return hidden_states + + +@keras_serializable +class TFAlbertMainLayer(keras.layers.Layer): + config_class = AlbertConfig + + def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFAlbertEmbeddings(config, name="embeddings") + self.encoder = TFAlbertTransformer(config, name="encoder") + self.pooler = ( + keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="pooler", + ) + if add_pooling_layer + else None + ) + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(dims=input_shape, value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build([None, None, self.config.hidden_size]) + + +@dataclass +class TFAlbertForPreTrainingOutput(ModelOutput): + """ + Output type of [`TFAlbertForPreTraining`]. + + Args: + prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + sop_logits (`tf.Tensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + prediction_logits: tf.Tensor | None = None + sop_logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | None = None + attentions: tuple[tf.Tensor] | None = None + + +ALBERT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`AlbertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ALBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
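As a hedged illustration of the input conventions documented above, the snippet below builds the described tensors with a tokenizer and passes them in the supported forms (the `albert/albert-base-v2` checkpoint is the one used in the examples elsewhere in this file):

```python
from transformers import AutoTokenizer, TFAlbertModel

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = TFAlbertModel.from_pretrained("albert/albert-base-v2")

encoded = tokenizer("Hello, my dog is cute", return_tensors="tf")

outputs = model(**encoded)                                    # keyword arguments, like PyTorch models
outputs = model(dict(encoded))                                # a dict in the first positional argument
outputs = model([encoded["input_ids"], encoded["attention_mask"]])  # a list, in docstring order
```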
+ + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.", + ALBERT_START_DOCSTRING, +) +class TFAlbertModel(TFAlbertPreTrainedModel): + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.albert = TFAlbertMainLayer(config, name="albert") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]: + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with tf.name_scope(self.albert.name): + self.albert.build(None) + + +@add_start_docstrings( + """ + Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order + 
prediction` (classification) head. + """, + ALBERT_START_DOCSTRING, +) +class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"] + + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.albert = TFAlbertMainLayer(config, name="albert") + self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions") + self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier") + + def get_lm_head(self) -> keras.layers.Layer: + return self.predictions + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + sentence_order_label: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]: + r""" + Return: + + Example: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFAlbertForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") + >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2") + + >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] + >>> # Batch size 1 + >>> outputs = model(input_ids) + + >>> prediction_logits = outputs.prediction_logits + >>> sop_logits = outputs.sop_logits + ```""" + + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output, pooled_output = outputs[:2] + prediction_scores = self.predictions(hidden_states=sequence_output) + sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training) + total_loss = None + + if labels is not None and sentence_order_label is not None: + d_labels = {"labels": labels} + d_labels["sentence_order_label"] = sentence_order_label + total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores)) + + if not return_dict: + output = (prediction_scores, sop_scores) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return TFAlbertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + sop_logits=sop_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with 
tf.name_scope(self.albert.name): + self.albert.build(None) + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + if getattr(self, "sop_classifier", None) is not None: + with tf.name_scope(self.sop_classifier.name): + self.sop_classifier.build(None) + + +class TFAlbertSOPHead(keras.layers.Layer): + def __init__(self, config: AlbertConfig, **kwargs): + super().__init__(**kwargs) + + self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.config = config + + def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor: + dropout_pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=dropout_pooled_output) + + return logits + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING) +class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"] + + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") + self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions") + + def get_lm_head(self) -> keras.layers.Layer: + return self.predictions + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Example: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFAlbertForMaskedLM + + >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2") + >>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2") + + >>> # add mask_token + >>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf") + >>> logits = model(**inputs).logits + + >>> # retrieve index of [MASK] + >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1] + >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1) + >>> tokenizer.decode(predicted_token_id) + 'france' + ``` + + ```python + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] + >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + >>> outputs = model(**inputs, labels=labels) + >>> round(float(outputs.loss), 2) + 0.81 + ``` + """ + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.predictions(hidden_states=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with tf.name_scope(self.albert.name): + self.albert.build(None) + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + + +@add_start_docstrings( + """ + Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + ALBERT_START_DOCSTRING, +) +class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"predictions"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.albert = TFAlbertMainLayer(config, name="albert") + self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="vumichien/albert-base-v2-imdb", + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'LABEL_1'", + expected_loss=0.12, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with tf.name_scope(self.albert.name): + self.albert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ALBERT_START_DOCSTRING, +) +class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") + classifier_dropout_prob = ( + config.classifier_dropout_prob + if config.classifier_dropout_prob is not None + else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob) + self.classifier = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(inputs=sequence_output, training=training) + logits = self.classifier(inputs=sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with tf.name_scope(self.albert.name): + self.albert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ALBERT_START_DOCSTRING, +) +class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"] + + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") + self.qa_outputs = keras.layers.Dense( + units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="vumichien/albert-base-v2-squad2", + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=12, + qa_target_end_index=13, + expected_output="'a nice puppet'", + expected_loss=7.36, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
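For context, a minimal usage sketch of the span-extraction head described above; the fine-tuned checkpoint name is taken from the code-sample decorator earlier in this class, and the question/context strings are illustrative only:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAlbertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-squad2")
model = TFAlbertForQuestionAnswering.from_pretrained("vumichien/albert-base-v2-squad2")

question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, context, return_tensors="tf")
outputs = model(**inputs)

# Greedy span extraction from the start/end logits.
start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
answer_ids = inputs["input_ids"][0, start_index : end_index + 1]
print(tokenizer.decode(answer_ids))  # expected to be something like "a nice puppet"
```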
+ """ + outputs = self.albert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.qa_outputs(inputs=sequence_output) + start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) + start_logits = tf.squeeze(input=start_logits, axis=-1) + end_logits = tf.squeeze(input=end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with tf.name_scope(self.albert.name): + self.albert.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ALBERT_START_DOCSTRING, +) +class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: AlbertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.albert = TFAlbertMainLayer(config, name="albert") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. 
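For orientation, a small shape walkthrough (made-up sizes, not part of this file) of how the multiple-choice head reshapes its inputs and logits, mirroring the reshapes in the `call` body below:

```python
import tensorflow as tf

# Made-up sizes: 2 examples, 4 candidate choices each, 8 tokens per (prompt, choice) pair.
batch_size, num_choices, seq_len = 2, 4, 8
input_ids = tf.zeros((batch_size, num_choices, seq_len), dtype=tf.int32)

# The head flattens the choices into the batch dimension before running the encoder ...
flat_input_ids = tf.reshape(input_ids, (-1, seq_len))        # shape (8, 8)

# ... and folds the per-choice scores back into (batch_size, num_choices) at the end,
# so classification is effectively a softmax over the choices for each example.
logits = tf.zeros((batch_size * num_choices, 1))
reshaped_logits = tf.reshape(logits, (-1, num_choices))      # shape (2, 4)
```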
Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = ( + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None + ) + flat_token_type_ids = ( + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None + ) + flat_position_ids = ( + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None + ) + flat_inputs_embeds = ( + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.albert( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "albert", None) is not None: + with tf.name_scope(self.albert.name): + self.albert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFAlbertPreTrainedModel", + "TFAlbertModel", + "TFAlbertForPreTraining", + "TFAlbertForMaskedLM", + "TFAlbertForSequenceClassification", + "TFAlbertForTokenClassification", + "TFAlbertForQuestionAnswering", + "TFAlbertForMultipleChoice", + "TFAlbertMainLayer", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py new file mode 100644 index 0000000000000000000000000000000000000000..011ad689edbdb10f53694eb9c774604d922a0d73 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py @@ -0,0 +1,320 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for ALBERT model.""" + +import os +import unicodedata +from shutil import copyfile +from typing import Any, Optional + +import sentencepiece as spm + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"} + + +SPIECE_UNDERLINE = "▁" + + +@requires(backends=("sentencepiece",)) +class AlbertTokenizer(PreTrainedTokenizer): + """ + Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + remove_space (`bool`, *optional*, defaults to `True`): + Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). + keep_accents (`bool`, *optional*, defaults to `False`): + Whether or not to keep accents when tokenizing. + bos_token (`str`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. 
+ sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + remove_space=True, + keep_accents=False, + bos_token="[CLS]", + eos_token="[SEP]", + unk_token="", + sep_token="[SEP]", + pad_token="", + cls_token="[CLS]", + mask_token="[MASK]", + sp_model_kwargs: Optional[dict[str, Any]] = None, + **kwargs, + ) -> None: + # Mask token behave like a normal word, i.e. include the space before it and + # is included in the raw text, there should be a match in a non-normalized sentence. + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.keep_accents = keep_accents + self.vocab_file = vocab_file + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + super().__init__( + do_lower_case=do_lower_case, + remove_space=remove_space, + keep_accents=keep_accents, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.sp_model) + + def get_vocab(self) -> dict[str, int]: + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def preprocess_text(self, inputs): + if self.remove_space: + outputs = " ".join(inputs.strip().split()) + else: + outputs = inputs + outputs = outputs.replace("``", '"').replace("''", '"') + + if not self.keep_accents: + outputs = unicodedata.normalize("NFKD", outputs) + outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) + if self.do_lower_case: + outputs = outputs.lower() + + return outputs + + def _tokenize(self, text: str) -> list[str]: + """Tokenize a string.""" + text = self.preprocess_text(text) + pieces = self.sp_model.encode(text, out_type=str) + new_pieces = [] + for piece in pieces: + if len(piece) > 1 and piece[-1] == "," and 
piece[-2].isdigit(): + # Logic to handle special cases see https://github.com/google-research/bert/blob/master/README.md#tokenization + # `9,9` -> ['▁9', ',', '9'] instead of [`_9,`, '9'] + cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, "")) + if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: + if len(cur_pieces[0]) == 1: + cur_pieces = cur_pieces[1:] + else: + cur_pieces[0] = cur_pieces[0][1:] + cur_pieces.append(piece[-1]) + new_pieces.extend(cur_pieces) + else: + new_pieces.append(piece) + + return new_pieces + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An ALBERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return cls + token_ids_0 + sep + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
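A short worked example of the two helpers above, using hypothetical token IDs (the actual IDs depend on the SentencePiece vocabulary):

```python
# Hypothetical IDs: suppose cls_token_id == 2 and sep_token_id == 3.
token_ids_0 = [10, 11, 12]      # first sequence, already converted to IDs, no special tokens
token_ids_1 = [20, 21]          # optional second sequence

# build_inputs_with_special_tokens for a pair -> [CLS] A [SEP] B [SEP]
pair_inputs = [2] + token_ids_0 + [3] + token_ids_1 + [3]
# -> [2, 10, 11, 12, 3, 20, 21, 3]

# get_special_tokens_mask (already_has_special_tokens=False) marks the added tokens with 1.
pair_mask = [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]
# -> [1, 0, 0, 0, 1, 0, 0, 1]
```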
+ """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + +__all__ = ["AlbertTokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..ed9add51d20743948dc1fe51ad6f5fe0c1ed1543 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for ALBERT model.""" + +import os +from shutil import copyfile +from typing import Optional + +from ...tokenization_utils import AddedToken +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import is_sentencepiece_available, logging + + +if is_sentencepiece_available(): + from .tokenization_albert import AlbertTokenizer +else: + AlbertTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + + +SPIECE_UNDERLINE = "▁" + + +class AlbertTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on + [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This + tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. 
+ remove_space (`bool`, *optional*, defaults to `True`): + Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). + keep_accents (`bool`, *optional*, defaults to `False`): + Whether or not to keep accents when tokenizing. + bos_token (`str`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. .. note:: When building a sequence using special tokens, this is not the token + that is used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = AlbertTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + remove_space=True, + keep_accents=False, + bos_token="[CLS]", + eos_token="[SEP]", + unk_token="", + sep_token="[SEP]", + pad_token="", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs, + ): + # Mask token behave like a normal word, i.e. include the space before it and + # is included in the raw text, there should be a match in a non-normalized sentence. + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + remove_space=remove_space, + keep_accents=keep_accents, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + self.do_lower_case = do_lower_case + self.remove_space = remove_space + self.keep_accents = keep_accents + self.vocab_file = vocab_file + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. 
An ALBERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return cls + token_ids_0 + sep + return cls + token_ids_0 + sep + token_ids_1 + sep + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + +__all__ = ["AlbertTokenizerFast"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34a6ae1e5c2e4f042c141624bd2296587e9f811d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .auto_factory import * + from .configuration_auto import * + from .feature_extraction_auto import * + from .image_processing_auto import * + from .modeling_auto import * + from .modeling_flax_auto import * + from .modeling_tf_auto import * + from .processing_auto import * + from .tokenization_auto import * + from .video_processing_auto import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..a8781c8042a6f51aa2b68f03a847f6a6320ec9ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py @@ -0,0 +1,882 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Factory function to build auto-model classes.""" + +import copy +import importlib +import json +import os +import warnings +from collections import OrderedDict +from collections.abc import Iterator +from typing import Any, TypeVar, Union + +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...utils import ( + CONFIG_NAME, + cached_file, + copy_func, + extract_commit_hash, + find_adapter_config_file, + is_peft_available, + is_torch_available, + logging, + requires_backends, +) +from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings + + +if is_torch_available(): + from ...generation import GenerationMixin + + +logger = logging.get_logger(__name__) + +_T = TypeVar("_T") +# Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol +_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]] + +CLASS_DOCSTRING = """ + This is a generic model class that will be instantiated as one of the model classes of the library when created + with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class + method. + + This class cannot be instantiated directly using `__init__()` (throws an error). +""" + +FROM_CONFIG_DOCSTRING = """ + Instantiates one of the model classes of the library from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. It only affects the + model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights. + + Args: + config ([`PretrainedConfig`]): + The model class to instantiate is selected based on the configuration class: + + List options + attn_implementation (`str`, *optional*): + The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download configuration from huggingface.co and cache. + >>> config = AutoConfig.from_pretrained("checkpoint_placeholder") + >>> model = BaseAutoModelClass.from_config(config) + ``` +""" + +FROM_PRETRAINED_TORCH_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. 
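A hedged sketch of the distinction drawn above between loading a configuration and loading weights; `AutoModel` and the ALBERT checkpoint merely stand in for any concrete auto class and repository:

```python
from transformers import AutoConfig, AutoModel

# from_config only builds the architecture; the weights are randomly initialised.
config = AutoConfig.from_pretrained("albert/albert-base-v2")
untrained_model = AutoModel.from_config(config)

# from_pretrained resolves the checkpoint, downloads/caches it, and loads the trained weights.
pretrained_model = AutoModel.from_pretrained("albert/albert-base-v2")
```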
+ + The model class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are + deactivated). To train the model, you should first set it back in training mode with `model.train()` + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args (additional positional arguments, *optional*): + Will be passed along to the underlying model `__init__()` method. + config ([`PretrainedConfig`], *optional*): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + state_dict (*dict[str, torch.Tensor]*, *optional*): + A state dictionary to use instead of a state dictionary loaded from saved weights file. + + This option can be used if you want to create a model from a pretrained configuration but load your own + weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and + [`~PreTrainedModel.from_pretrained`] is not a simpler option. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_tf (`bool`, *optional*, defaults to `False`): + Load the model weights from a TensorFlow checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. 
+ local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than + the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based + system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier + allowed by git. + kwargs (additional keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) + >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json") + >>> model = BaseAutoModelClass.from_pretrained( + ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config + ... ) + ``` +""" + +FROM_PRETRAINED_TF_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. 
+ - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this + case, `from_pt` should be set to `True` and a configuration object should be provided as `config` + argument. This loading path is slower than converting the PyTorch model in a TensorFlow model + using the provided conversion scripts and loading the TensorFlow model afterwards. + model_args (additional positional arguments, *optional*): + Will be passed along to the underlying model `__init__()` method. + config ([`PretrainedConfig`], *optional*): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than + the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based + system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier + allowed by git. 
+ kwargs (additional keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) + >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") + >>> model = BaseAutoModelClass.from_pretrained( + ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config + ... ) + ``` +""" + +FROM_PRETRAINED_FLAX_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this + case, `from_pt` should be set to `True` and a configuration object should be provided as `config` + argument. This loading path is slower than converting the PyTorch model in a TensorFlow model + using the provided conversion scripts and loading the TensorFlow model afterwards. + model_args (additional positional arguments, *optional*): + Will be passed along to the underlying model `__init__()` method. + config ([`PretrainedConfig`], *optional*): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. 
+ cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (e.g., not try downloading the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + code_revision (`str`, *optional*, defaults to `"main"`): + The specific revision to use for the code on the Hub, if the code leaves in a different repository than + the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based + system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier + allowed by git. + kwargs (additional keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import AutoConfig, BaseAutoModelClass + + >>> # Download model and configuration from huggingface.co and cache. 
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") + + >>> # Update configuration during loading + >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) + >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") + >>> model = BaseAutoModelClass.from_pretrained( + ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config + ... ) + ``` +""" + + +def _get_model_class(config, model_mapping): + supported_models = model_mapping[type(config)] + if not isinstance(supported_models, (list, tuple)): + return supported_models + + name_to_model = {model.__name__: model for model in supported_models} + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in name_to_model: + return name_to_model[arch] + elif f"TF{arch}" in name_to_model: + return name_to_model[f"TF{arch}"] + elif f"Flax{arch}" in name_to_model: + return name_to_model[f"Flax{arch}"] + + # If not architecture is set in the config or match the supported models, the first element of the tuple is the + # defaults. + return supported_models[0] + + +class _BaseAutoModelClass: + # Base class for auto models. + _model_mapping = None + + def __init__(self, *args, **kwargs) -> None: + raise OSError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_config(config)` methods." + ) + + @classmethod + def from_config(cls, config, **kwargs): + trust_remote_code = kwargs.pop("trust_remote_code", None) + has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map + has_local_code = type(config) in cls._model_mapping + if has_remote_code: + class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo + ) + + if has_remote_code and trust_remote_code: + if "--" in class_ref: + repo_id, class_ref = class_ref.split("--") + else: + repo_id = config.name_or_path + model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) + # This block handles the case where the user is loading a model with `trust_remote_code=True` + # but a library model exists with the same name. We don't want to override the autoclass + # mappings in this case, or all future loads of that model will be the remote code model. + if not has_local_code: + cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) + _ = kwargs.pop("code_revision", None) + model_class = add_generation_mixin_to_remote_model(model_class) + return model_class._from_config(config, **kwargs) + elif type(config) in cls._model_mapping: + model_class = _get_model_class(config, cls._model_mapping) + return model_class._from_config(config, **kwargs) + + raise ValueError( + f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}." 
+ ) + + @classmethod + def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig: + """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses.""" + return config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs): + config = kwargs.pop("config", None) + trust_remote_code = kwargs.get("trust_remote_code") + kwargs["_from_auto"] = True + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "use_auth_token", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + code_revision = kwargs.pop("code_revision", None) + commit_hash = kwargs.pop("_commit_hash", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", None) + + token = hub_kwargs.pop("token", None) + use_auth_token = hub_kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + hub_kwargs["token"] = token + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + **hub_kwargs, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + if is_peft_available(): + if adapter_kwargs is None: + adapter_kwargs = {} + if token is not None: + adapter_kwargs["token"] = token + + maybe_adapter_path = find_adapter_config_file( + pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs + ) + + if maybe_adapter_path is not None: + with open(maybe_adapter_path, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + + adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path + pretrained_model_name_or_path = adapter_config["base_model_name_or_path"] + + if not isinstance(config, PretrainedConfig): + kwargs_orig = copy.deepcopy(kwargs) + # ensure not to pollute the config object with dtype="auto" - since it's + # meaningless in the context of the config object - torch.dtype values are acceptable + if kwargs.get("torch_dtype") == "auto": + _ = kwargs.pop("torch_dtype") + if kwargs.get("dtype") == "auto": + _ = kwargs.pop("dtype") + # to not overwrite the quantization_config if config has a quantization_config + if kwargs.get("quantization_config") is not None: + _ = kwargs.pop("quantization_config") + + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + return_unused_kwargs=True, + code_revision=code_revision, + _commit_hash=commit_hash, + **hub_kwargs, + **kwargs, + ) + + # if torch_dtype=auto was passed here, ensure to pass it on + if kwargs_orig.get("torch_dtype", None) == "auto": + kwargs["torch_dtype"] = "auto" + if kwargs_orig.get("dtype", None) == "auto": + kwargs["dtype"] = "auto" + if kwargs_orig.get("quantization_config", 
None) is not None: + kwargs["quantization_config"] = kwargs_orig["quantization_config"] + + has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map + has_local_code = type(config) in cls._model_mapping + upstream_repo = None + if has_remote_code: + class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, + pretrained_model_name_or_path, + has_local_code, + has_remote_code, + upstream_repo=upstream_repo, + ) + kwargs["trust_remote_code"] = trust_remote_code + + # Set the adapter kwargs + kwargs["adapter_kwargs"] = adapter_kwargs + + if has_remote_code and trust_remote_code: + model_class = get_class_from_dynamic_module( + class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs + ) + _ = hub_kwargs.pop("code_revision", None) + # This block handles the case where the user is loading a model with `trust_remote_code=True` + # but a library model exists with the same name. We don't want to override the autoclass + # mappings in this case, or all future loads of that model will be the remote code model. + if not has_local_code: + cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) + model_class = add_generation_mixin_to_remote_model(model_class) + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs + ) + elif type(config) in cls._model_mapping: + model_class = _get_model_class(config, cls._model_mapping) + if model_class.config_class == config.sub_configs.get("text_config", None): + config = config.get_text_config() + return model_class.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs + ) + raise ValueError( + f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}." + ) + + @classmethod + def register(cls, config_class, model_class, exist_ok=False) -> None: + """ + Register a new model for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + model_class ([`PreTrainedModel`]): + The model to register. + """ + if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__: + raise ValueError( + "The model class you are passing has a `config_class` attribute that is not consistent with the " + f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix " + "one of those so they match!" + ) + cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok) + + +class _BaseAutoBackboneClass(_BaseAutoModelClass): + # Base class for auto backbone models. 
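+    # [Editor's note — illustrative sketch, not part of the upstream file] Concrete subclasses such
+    # as `AutoBackbone` (defined in modeling_auto.py) expose this machinery to users, e.g.:
+    #
+    #     from transformers import AutoBackbone
+    #     # a Transformers checkpoint (example name only):
+    #     backbone = AutoBackbone.from_pretrained("microsoft/resnet-50", out_indices=(1, 2, 3, 4))
+    #     # a timm checkpoint, routed through `_load_timm_backbone_from_pretrained` below:
+    #     backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)
+    #
+    # `out_indices` selects which feature maps the backbone returns.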
+ _model_mapping = None + + @classmethod + def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_backbone import TimmBackboneConfig + + config = kwargs.pop("config", TimmBackboneConfig()) + + if kwargs.get("out_features") is not None: + raise ValueError("Cannot specify `out_features` for timm backbones") + + if kwargs.get("output_loading_info", False): + raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") + + num_channels = kwargs.pop("num_channels", config.num_channels) + features_only = kwargs.pop("features_only", config.features_only) + use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone) + out_indices = kwargs.pop("out_indices", config.out_indices) + config = TimmBackboneConfig( + backbone=pretrained_model_name_or_path, + num_channels=num_channels, + features_only=features_only, + use_pretrained_backbone=use_pretrained_backbone, + out_indices=out_indices, + ) + return super().from_config(config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + use_timm_backbone = kwargs.pop("use_timm_backbone", False) + if use_timm_backbone: + return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +def insert_head_doc(docstring, head_doc: str = ""): + if len(head_doc) > 0: + return docstring.replace( + "one of the model classes of the library ", + f"one of the model classes of the library (with a {head_doc} head) ", + ) + return docstring.replace( + "one of the model classes of the library ", "one of the base model classes of the library " + ) + + +def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""): + # Create a new class with the right name from the base class + model_mapping = cls._model_mapping + name = cls.__name__ + class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc) + cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name) + + # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't + # have a specific docstrings for them. 
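+    # [Editor's note — descriptive comment, not part of the upstream file] `copy_func` gives each
+    # auto class its own function object, so assigning `__doc__` below does not mutate the shared
+    # `_BaseAutoModelClass` methods. The "BaseAutoModelClass" / "checkpoint_placeholder" /
+    # "shortcut_placeholder" markers in the template docstrings above are then substituted per
+    # class; e.g. for `AutoModel` with the default `checkpoint_for_example`, the rendered example
+    # reads:
+    #
+    #     >>> model = AutoModel.from_pretrained("google-bert/bert-base-cased")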
+ from_config = copy_func(_BaseAutoModelClass.from_config) + from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc) + from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name) + from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) + from_config.__doc__ = from_config_docstring + from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) + cls.from_config = classmethod(from_config) + + if name.startswith("TF"): + from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING + elif name.startswith("Flax"): + from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING + else: + from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING + from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained) + from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc) + from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name) + from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example) + shortcut = checkpoint_for_example.split("/")[-1].split("-")[0] + from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) + from_pretrained.__doc__ = from_pretrained_docstring + from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained) + cls.from_pretrained = classmethod(from_pretrained) + return cls + + +def get_values(model_mapping): + result = [] + for model in model_mapping.values(): + if isinstance(model, (list, tuple)): + result += list(model) + else: + result.append(model) + + return result + + +def getattribute_from_module(module, attr): + if attr is None: + return None + if isinstance(attr, tuple): + return tuple(getattribute_from_module(module, a) for a in attr) + if hasattr(module, attr): + return getattr(module, attr) + # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the + # object at the top level. + transformers_module = importlib.import_module("transformers") + + if module != transformers_module: + try: + return getattribute_from_module(transformers_module, attr) + except ValueError: + raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!") + else: + raise ValueError(f"Could not find {attr} in {transformers_module}!") + + +def add_generation_mixin_to_remote_model(model_class): + """ + Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model. + + This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make + `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded + from the Hub may not have the `generate` method after we remove the inheritance. + """ + # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing + if "torch.nn.modules.module.Module" not in str(model_class.__mro__): + return model_class + + # 2. If it already **directly** inherits from GenerationMixin, do nothing + if "GenerationMixin" in str(model_class.__bases__): + return model_class + + # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or + # `prepare_inputs_for_generation` method. 
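+    # [Editor's note — descriptive comment, not part of the upstream file] The string checks below
+    # rely on the repr of the resolved attribute: if `model_class.generate` (or
+    # `prepare_inputs_for_generation`) is the one inherited from `GenerationMixin`, its repr
+    # contains "GenerationMixin"; if the remote model defined its own, it does not, in which case a
+    # new class is created that also inherits from `GenerationMixin` so `.generate()` keeps working.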
+ has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str( + getattr(model_class, "generate") + ) + has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str( + getattr(model_class, "prepare_inputs_for_generation") + ) + if has_custom_generate_in_class or has_custom_prepare_inputs: + model_class_with_generation_mixin = type( + model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__} + ) + return model_class_with_generation_mixin + return model_class + + +class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]): + """ + " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. + + Args: + - config_mapping: The map model type to config class + - model_mapping: The map model type to model (or tokenizer) class + """ + + def __init__(self, config_mapping, model_mapping) -> None: + self._config_mapping = config_mapping + self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} + self._model_mapping = model_mapping + self._model_mapping._model_mapping = self + self._extra_content = {} + self._modules = {} + + def __len__(self) -> int: + common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys()) + return len(common_keys) + len(self._extra_content) + + def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue: + if key in self._extra_content: + return self._extra_content[key] + model_type = self._reverse_config_mapping[key.__name__] + if model_type in self._model_mapping: + model_name = self._model_mapping[model_type] + return self._load_attr_from_module(model_type, model_name) + + # Maybe there was several model types associated with this config. 
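+        # [Editor's note — descriptive comment, not part of the upstream file] Several `model_type`
+        # keys can share a single config class name, e.g. both "llama" and "code_llama" map to
+        # "LlamaConfig" in CONFIG_MAPPING_NAMES, so the reverse lookup above may point at a model
+        # type with no entry in this mapping; in that case we scan every model type that uses this
+        # config name.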
+ model_types = [k for k, v in self._config_mapping.items() if v == key.__name__] + for mtype in model_types: + if mtype in self._model_mapping: + model_name = self._model_mapping[mtype] + return self._load_attr_from_module(mtype, model_name) + raise KeyError(key) + + def _load_attr_from_module(self, model_type, attr): + module_name = model_type_to_module_name(model_type) + if module_name not in self._modules: + self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") + return getattribute_from_module(self._modules[module_name], attr) + + def keys(self) -> list[type[PretrainedConfig]]: + mapping_keys = [ + self._load_attr_from_module(key, name) + for key, name in self._config_mapping.items() + if key in self._model_mapping + ] + return mapping_keys + list(self._extra_content.keys()) + + def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]: + try: + return self.__getitem__(key) + except KeyError: + return default + + def __bool__(self) -> bool: + return bool(self.keys()) + + def values(self) -> list[_LazyAutoMappingValue]: + mapping_values = [ + self._load_attr_from_module(key, name) + for key, name in self._model_mapping.items() + if key in self._config_mapping + ] + return mapping_values + list(self._extra_content.values()) + + def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]: + mapping_items = [ + ( + self._load_attr_from_module(key, self._config_mapping[key]), + self._load_attr_from_module(key, self._model_mapping[key]), + ) + for key in self._model_mapping + if key in self._config_mapping + ] + return mapping_items + list(self._extra_content.items()) + + def __iter__(self) -> Iterator[type[PretrainedConfig]]: + return iter(self.keys()) + + def __contains__(self, item: type) -> bool: + if item in self._extra_content: + return True + if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: + return False + model_type = self._reverse_config_mapping[item.__name__] + return model_type in self._model_mapping + + def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None: + """ + Register a new model in this mapping. + """ + if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: + model_type = self._reverse_config_mapping[key.__name__] + if model_type in self._model_mapping and not exist_ok: + raise ValueError(f"'{key}' is already used by a Transformers model.") + + self._extra_content[key] = value + + +__all__ = ["get_values"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a12e7cef986fb837abbfee0cc81b64b7148b50 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py @@ -0,0 +1,1404 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Auto Config class.""" + +import importlib +import os +import re +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator, KeysView, ValuesView +from typing import Any, TypeVar, Union + +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...utils import CONFIG_NAME, logging + + +logger = logging.get_logger(__name__) + + +_CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) + + +CONFIG_MAPPING_NAMES = OrderedDict[str, str]( + [ + # Add configs here + ("aimv2", "Aimv2Config"), + ("aimv2_vision_model", "Aimv2VisionConfig"), + ("albert", "AlbertConfig"), + ("align", "AlignConfig"), + ("altclip", "AltCLIPConfig"), + ("apertus", "ApertusConfig"), + ("arcee", "ArceeConfig"), + ("aria", "AriaConfig"), + ("aria_text", "AriaTextConfig"), + ("audio-spectrogram-transformer", "ASTConfig"), + ("autoformer", "AutoformerConfig"), + ("aya_vision", "AyaVisionConfig"), + ("bamba", "BambaConfig"), + ("bark", "BarkConfig"), + ("bart", "BartConfig"), + ("beit", "BeitConfig"), + ("bert", "BertConfig"), + ("bert-generation", "BertGenerationConfig"), + ("big_bird", "BigBirdConfig"), + ("bigbird_pegasus", "BigBirdPegasusConfig"), + ("biogpt", "BioGptConfig"), + ("bit", "BitConfig"), + ("bitnet", "BitNetConfig"), + ("blenderbot", "BlenderbotConfig"), + ("blenderbot-small", "BlenderbotSmallConfig"), + ("blip", "BlipConfig"), + ("blip-2", "Blip2Config"), + ("blip_2_qformer", "Blip2QFormerConfig"), + ("bloom", "BloomConfig"), + ("blt", "BltConfig"), + ("bridgetower", "BridgeTowerConfig"), + ("bros", "BrosConfig"), + ("camembert", "CamembertConfig"), + ("canine", "CanineConfig"), + ("chameleon", "ChameleonConfig"), + ("chinese_clip", "ChineseCLIPConfig"), + ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"), + ("clap", "ClapConfig"), + ("clip", "CLIPConfig"), + ("clip_text_model", "CLIPTextConfig"), + ("clip_vision_model", "CLIPVisionConfig"), + ("clipseg", "CLIPSegConfig"), + ("clvp", "ClvpConfig"), + ("code_llama", "LlamaConfig"), + ("codegen", "CodeGenConfig"), + ("cohere", "CohereConfig"), + ("cohere2", "Cohere2Config"), + ("cohere2_vision", "Cohere2VisionConfig"), + ("colpali", "ColPaliConfig"), + ("colqwen2", "ColQwen2Config"), + ("conditional_detr", "ConditionalDetrConfig"), + ("convbert", "ConvBertConfig"), + ("convnext", "ConvNextConfig"), + ("convnextv2", "ConvNextV2Config"), + ("cpmant", "CpmAntConfig"), + ("csm", "CsmConfig"), + ("ctrl", "CTRLConfig"), + ("cvt", "CvtConfig"), + ("d_fine", "DFineConfig"), + ("dab-detr", "DabDetrConfig"), + ("dac", "DacConfig"), + ("data2vec-audio", "Data2VecAudioConfig"), + ("data2vec-text", "Data2VecTextConfig"), + ("data2vec-vision", "Data2VecVisionConfig"), + ("dbrx", "DbrxConfig"), + ("deberta", "DebertaConfig"), + ("deberta-v2", "DebertaV2Config"), + ("decision_transformer", "DecisionTransformerConfig"), + ("deepseek_v2", "DeepseekV2Config"), + ("deepseek_v3", "DeepseekV3Config"), + ("deepseek_vl", "DeepseekVLConfig"), + ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"), + ("deformable_detr", "DeformableDetrConfig"), + ("deit", "DeiTConfig"), + ("depth_anything", "DepthAnythingConfig"), + ("depth_pro", "DepthProConfig"), + ("deta", "DetaConfig"), + ("detr", "DetrConfig"), + ("dia", "DiaConfig"), + ("diffllama", "DiffLlamaConfig"), + ("dinat", "DinatConfig"), + ("dinov2", "Dinov2Config"), + 
("dinov2_with_registers", "Dinov2WithRegistersConfig"), + ("dinov3_convnext", "DINOv3ConvNextConfig"), + ("dinov3_vit", "DINOv3ViTConfig"), + ("distilbert", "DistilBertConfig"), + ("doge", "DogeConfig"), + ("donut-swin", "DonutSwinConfig"), + ("dots1", "Dots1Config"), + ("dpr", "DPRConfig"), + ("dpt", "DPTConfig"), + ("edgetam", "EdgeTamConfig"), + ("edgetam_video", "EdgeTamVideoConfig"), + ("edgetam_vision_model", "EdgeTamVisionConfig"), + ("efficientformer", "EfficientFormerConfig"), + ("efficientloftr", "EfficientLoFTRConfig"), + ("efficientnet", "EfficientNetConfig"), + ("electra", "ElectraConfig"), + ("emu3", "Emu3Config"), + ("encodec", "EncodecConfig"), + ("encoder-decoder", "EncoderDecoderConfig"), + ("eomt", "EomtConfig"), + ("ernie", "ErnieConfig"), + ("ernie4_5", "Ernie4_5Config"), + ("ernie4_5_moe", "Ernie4_5_MoeConfig"), + ("ernie_m", "ErnieMConfig"), + ("esm", "EsmConfig"), + ("evolla", "EvollaConfig"), + ("exaone4", "Exaone4Config"), + ("falcon", "FalconConfig"), + ("falcon_h1", "FalconH1Config"), + ("falcon_mamba", "FalconMambaConfig"), + ("fastspeech2_conformer", "FastSpeech2ConformerConfig"), + ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"), + ("flaubert", "FlaubertConfig"), + ("flava", "FlavaConfig"), + ("flex_olmo", "FlexOlmoConfig"), + ("florence2", "Florence2Config"), + ("fnet", "FNetConfig"), + ("focalnet", "FocalNetConfig"), + ("fsmt", "FSMTConfig"), + ("funnel", "FunnelConfig"), + ("fuyu", "FuyuConfig"), + ("gemma", "GemmaConfig"), + ("gemma2", "Gemma2Config"), + ("gemma3", "Gemma3Config"), + ("gemma3_text", "Gemma3TextConfig"), + ("gemma3n", "Gemma3nConfig"), + ("gemma3n_audio", "Gemma3nAudioConfig"), + ("gemma3n_text", "Gemma3nTextConfig"), + ("gemma3n_vision", "Gemma3nVisionConfig"), + ("git", "GitConfig"), + ("glm", "GlmConfig"), + ("glm4", "Glm4Config"), + ("glm4_moe", "Glm4MoeConfig"), + ("glm4v", "Glm4vConfig"), + ("glm4v_moe", "Glm4vMoeConfig"), + ("glm4v_moe_text", "Glm4vMoeTextConfig"), + ("glm4v_text", "Glm4vTextConfig"), + ("glpn", "GLPNConfig"), + ("got_ocr2", "GotOcr2Config"), + ("gpt-sw3", "GPT2Config"), + ("gpt2", "GPT2Config"), + ("gpt_bigcode", "GPTBigCodeConfig"), + ("gpt_neo", "GPTNeoConfig"), + ("gpt_neox", "GPTNeoXConfig"), + ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"), + ("gpt_oss", "GptOssConfig"), + ("gptj", "GPTJConfig"), + ("gptsan-japanese", "GPTSanJapaneseConfig"), + ("granite", "GraniteConfig"), + ("granite_speech", "GraniteSpeechConfig"), + ("granitemoe", "GraniteMoeConfig"), + ("granitemoehybrid", "GraniteMoeHybridConfig"), + ("granitemoeshared", "GraniteMoeSharedConfig"), + ("granitevision", "LlavaNextConfig"), + ("graphormer", "GraphormerConfig"), + ("grounding-dino", "GroundingDinoConfig"), + ("groupvit", "GroupViTConfig"), + ("helium", "HeliumConfig"), + ("hgnet_v2", "HGNetV2Config"), + ("hiera", "HieraConfig"), + ("hubert", "HubertConfig"), + ("hunyuan_v1_dense", "HunYuanDenseV1Config"), + ("hunyuan_v1_moe", "HunYuanMoEV1Config"), + ("ibert", "IBertConfig"), + ("idefics", "IdeficsConfig"), + ("idefics2", "Idefics2Config"), + ("idefics3", "Idefics3Config"), + ("idefics3_vision", "Idefics3VisionConfig"), + ("ijepa", "IJepaConfig"), + ("imagegpt", "ImageGPTConfig"), + ("informer", "InformerConfig"), + ("instructblip", "InstructBlipConfig"), + ("instructblipvideo", "InstructBlipVideoConfig"), + ("internvl", "InternVLConfig"), + ("internvl_vision", "InternVLVisionConfig"), + ("jamba", "JambaConfig"), + ("janus", "JanusConfig"), + ("jetmoe", "JetMoeConfig"), + ("jukebox", "JukeboxConfig"), + 
("kosmos-2", "Kosmos2Config"), + ("kosmos-2.5", "Kosmos2_5Config"), + ("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"), + ("layoutlm", "LayoutLMConfig"), + ("layoutlmv2", "LayoutLMv2Config"), + ("layoutlmv3", "LayoutLMv3Config"), + ("led", "LEDConfig"), + ("levit", "LevitConfig"), + ("lfm2", "Lfm2Config"), + ("lfm2_vl", "Lfm2VlConfig"), + ("lightglue", "LightGlueConfig"), + ("lilt", "LiltConfig"), + ("llama", "LlamaConfig"), + ("llama4", "Llama4Config"), + ("llama4_text", "Llama4TextConfig"), + ("llava", "LlavaConfig"), + ("llava_next", "LlavaNextConfig"), + ("llava_next_video", "LlavaNextVideoConfig"), + ("llava_onevision", "LlavaOnevisionConfig"), + ("longcat_flash", "LongcatFlashConfig"), + ("longformer", "LongformerConfig"), + ("longt5", "LongT5Config"), + ("luke", "LukeConfig"), + ("lxmert", "LxmertConfig"), + ("m2m_100", "M2M100Config"), + ("mamba", "MambaConfig"), + ("mamba2", "Mamba2Config"), + ("marian", "MarianConfig"), + ("markuplm", "MarkupLMConfig"), + ("mask2former", "Mask2FormerConfig"), + ("maskformer", "MaskFormerConfig"), + ("maskformer-swin", "MaskFormerSwinConfig"), + ("mbart", "MBartConfig"), + ("mctct", "MCTCTConfig"), + ("mega", "MegaConfig"), + ("megatron-bert", "MegatronBertConfig"), + ("metaclip_2", "MetaClip2Config"), + ("mgp-str", "MgpstrConfig"), + ("mimi", "MimiConfig"), + ("minimax", "MiniMaxConfig"), + ("ministral", "MinistralConfig"), + ("mistral", "MistralConfig"), + ("mistral3", "Mistral3Config"), + ("mixtral", "MixtralConfig"), + ("mlcd", "MLCDVisionConfig"), + ("mllama", "MllamaConfig"), + ("mm-grounding-dino", "MMGroundingDinoConfig"), + ("mobilebert", "MobileBertConfig"), + ("mobilenet_v1", "MobileNetV1Config"), + ("mobilenet_v2", "MobileNetV2Config"), + ("mobilevit", "MobileViTConfig"), + ("mobilevitv2", "MobileViTV2Config"), + ("modernbert", "ModernBertConfig"), + ("modernbert-decoder", "ModernBertDecoderConfig"), + ("moonshine", "MoonshineConfig"), + ("moshi", "MoshiConfig"), + ("mpnet", "MPNetConfig"), + ("mpt", "MptConfig"), + ("mra", "MraConfig"), + ("mt5", "MT5Config"), + ("musicgen", "MusicgenConfig"), + ("musicgen_melody", "MusicgenMelodyConfig"), + ("mvp", "MvpConfig"), + ("nat", "NatConfig"), + ("nemotron", "NemotronConfig"), + ("nezha", "NezhaConfig"), + ("nllb-moe", "NllbMoeConfig"), + ("nougat", "VisionEncoderDecoderConfig"), + ("nystromformer", "NystromformerConfig"), + ("olmo", "OlmoConfig"), + ("olmo2", "Olmo2Config"), + ("olmo3", "Olmo3Config"), + ("olmoe", "OlmoeConfig"), + ("omdet-turbo", "OmDetTurboConfig"), + ("oneformer", "OneFormerConfig"), + ("open-llama", "OpenLlamaConfig"), + ("openai-gpt", "OpenAIGPTConfig"), + ("opt", "OPTConfig"), + ("ovis2", "Ovis2Config"), + ("owlv2", "Owlv2Config"), + ("owlvit", "OwlViTConfig"), + ("paligemma", "PaliGemmaConfig"), + ("parakeet_ctc", "ParakeetCTCConfig"), + ("parakeet_encoder", "ParakeetEncoderConfig"), + ("patchtsmixer", "PatchTSMixerConfig"), + ("patchtst", "PatchTSTConfig"), + ("pegasus", "PegasusConfig"), + ("pegasus_x", "PegasusXConfig"), + ("perceiver", "PerceiverConfig"), + ("perception_encoder", "TimmWrapperConfig"), + ("perception_lm", "PerceptionLMConfig"), + ("persimmon", "PersimmonConfig"), + ("phi", "PhiConfig"), + ("phi3", "Phi3Config"), + ("phi4_multimodal", "Phi4MultimodalConfig"), + ("phimoe", "PhimoeConfig"), + ("pix2struct", "Pix2StructConfig"), + ("pixtral", "PixtralVisionConfig"), + ("plbart", "PLBartConfig"), + ("poolformer", "PoolFormerConfig"), + ("pop2piano", "Pop2PianoConfig"), + ("prompt_depth_anything", "PromptDepthAnythingConfig"), + ("prophetnet", 
"ProphetNetConfig"), + ("pvt", "PvtConfig"), + ("pvt_v2", "PvtV2Config"), + ("qdqbert", "QDQBertConfig"), + ("qwen2", "Qwen2Config"), + ("qwen2_5_omni", "Qwen2_5OmniConfig"), + ("qwen2_5_vl", "Qwen2_5_VLConfig"), + ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"), + ("qwen2_audio", "Qwen2AudioConfig"), + ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"), + ("qwen2_moe", "Qwen2MoeConfig"), + ("qwen2_vl", "Qwen2VLConfig"), + ("qwen2_vl_text", "Qwen2VLTextConfig"), + ("qwen3", "Qwen3Config"), + ("qwen3_moe", "Qwen3MoeConfig"), + ("qwen3_next", "Qwen3NextConfig"), + ("qwen3_omni_moe", "Qwen3OmniMoeConfig"), + ("qwen3_vl", "Qwen3VLConfig"), + ("qwen3_vl_moe", "Qwen3VLMoeConfig"), + ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"), + ("qwen3_vl_text", "Qwen3VLTextConfig"), + ("rag", "RagConfig"), + ("realm", "RealmConfig"), + ("recurrent_gemma", "RecurrentGemmaConfig"), + ("reformer", "ReformerConfig"), + ("regnet", "RegNetConfig"), + ("rembert", "RemBertConfig"), + ("resnet", "ResNetConfig"), + ("retribert", "RetriBertConfig"), + ("roberta", "RobertaConfig"), + ("roberta-prelayernorm", "RobertaPreLayerNormConfig"), + ("roc_bert", "RoCBertConfig"), + ("roformer", "RoFormerConfig"), + ("rt_detr", "RTDetrConfig"), + ("rt_detr_resnet", "RTDetrResNetConfig"), + ("rt_detr_v2", "RTDetrV2Config"), + ("rwkv", "RwkvConfig"), + ("sam", "SamConfig"), + ("sam2", "Sam2Config"), + ("sam2_hiera_det_model", "Sam2HieraDetConfig"), + ("sam2_video", "Sam2VideoConfig"), + ("sam2_vision_model", "Sam2VisionConfig"), + ("sam_hq", "SamHQConfig"), + ("sam_hq_vision_model", "SamHQVisionConfig"), + ("sam_vision_model", "SamVisionConfig"), + ("seamless_m4t", "SeamlessM4TConfig"), + ("seamless_m4t_v2", "SeamlessM4Tv2Config"), + ("seed_oss", "SeedOssConfig"), + ("segformer", "SegformerConfig"), + ("seggpt", "SegGptConfig"), + ("sew", "SEWConfig"), + ("sew-d", "SEWDConfig"), + ("shieldgemma2", "ShieldGemma2Config"), + ("siglip", "SiglipConfig"), + ("siglip2", "Siglip2Config"), + ("siglip2_vision_model", "Siglip2VisionConfig"), + ("siglip_vision_model", "SiglipVisionConfig"), + ("smollm3", "SmolLM3Config"), + ("smolvlm", "SmolVLMConfig"), + ("smolvlm_vision", "SmolVLMVisionConfig"), + ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"), + ("speech_to_text", "Speech2TextConfig"), + ("speech_to_text_2", "Speech2Text2Config"), + ("speecht5", "SpeechT5Config"), + ("splinter", "SplinterConfig"), + ("squeezebert", "SqueezeBertConfig"), + ("stablelm", "StableLmConfig"), + ("starcoder2", "Starcoder2Config"), + ("superglue", "SuperGlueConfig"), + ("superpoint", "SuperPointConfig"), + ("swiftformer", "SwiftFormerConfig"), + ("swin", "SwinConfig"), + ("swin2sr", "Swin2SRConfig"), + ("swinv2", "Swinv2Config"), + ("switch_transformers", "SwitchTransformersConfig"), + ("t5", "T5Config"), + ("t5gemma", "T5GemmaConfig"), + ("table-transformer", "TableTransformerConfig"), + ("tapas", "TapasConfig"), + ("textnet", "TextNetConfig"), + ("time_series_transformer", "TimeSeriesTransformerConfig"), + ("timesfm", "TimesFmConfig"), + ("timesformer", "TimesformerConfig"), + ("timm_backbone", "TimmBackboneConfig"), + ("timm_wrapper", "TimmWrapperConfig"), + ("trajectory_transformer", "TrajectoryTransformerConfig"), + ("transfo-xl", "TransfoXLConfig"), + ("trocr", "TrOCRConfig"), + ("tvlt", "TvltConfig"), + ("tvp", "TvpConfig"), + ("udop", "UdopConfig"), + ("umt5", "UMT5Config"), + ("unispeech", "UniSpeechConfig"), + ("unispeech-sat", "UniSpeechSatConfig"), + ("univnet", "UnivNetConfig"), + ("upernet", "UperNetConfig"), + ("van", "VanConfig"), + 
("vaultgemma", "VaultGemmaConfig"), + ("video_llava", "VideoLlavaConfig"), + ("videomae", "VideoMAEConfig"), + ("vilt", "ViltConfig"), + ("vipllava", "VipLlavaConfig"), + ("vision-encoder-decoder", "VisionEncoderDecoderConfig"), + ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"), + ("visual_bert", "VisualBertConfig"), + ("vit", "ViTConfig"), + ("vit_hybrid", "ViTHybridConfig"), + ("vit_mae", "ViTMAEConfig"), + ("vit_msn", "ViTMSNConfig"), + ("vitdet", "VitDetConfig"), + ("vitmatte", "VitMatteConfig"), + ("vitpose", "VitPoseConfig"), + ("vitpose_backbone", "VitPoseBackboneConfig"), + ("vits", "VitsConfig"), + ("vivit", "VivitConfig"), + ("vjepa2", "VJEPA2Config"), + ("voxtral", "VoxtralConfig"), + ("voxtral_encoder", "VoxtralEncoderConfig"), + ("wav2vec2", "Wav2Vec2Config"), + ("wav2vec2-bert", "Wav2Vec2BertConfig"), + ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"), + ("wavlm", "WavLMConfig"), + ("whisper", "WhisperConfig"), + ("xclip", "XCLIPConfig"), + ("xcodec", "XcodecConfig"), + ("xglm", "XGLMConfig"), + ("xlm", "XLMConfig"), + ("xlm-prophetnet", "XLMProphetNetConfig"), + ("xlm-roberta", "XLMRobertaConfig"), + ("xlm-roberta-xl", "XLMRobertaXLConfig"), + ("xlnet", "XLNetConfig"), + ("xlstm", "xLSTMConfig"), + ("xmod", "XmodConfig"), + ("yolos", "YolosConfig"), + ("yoso", "YosoConfig"), + ("zamba", "ZambaConfig"), + ("zamba2", "Zamba2Config"), + ("zoedepth", "ZoeDepthConfig"), + ] +) + + +MODEL_NAMES_MAPPING = OrderedDict[str, str]( + [ + # Add full (and cased) model names here + ("aimv2", "AIMv2"), + ("aimv2_vision_model", "Aimv2VisionModel"), + ("albert", "ALBERT"), + ("align", "ALIGN"), + ("altclip", "AltCLIP"), + ("apertus", "Apertus"), + ("arcee", "Arcee"), + ("aria", "Aria"), + ("aria_text", "AriaText"), + ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), + ("autoformer", "Autoformer"), + ("aya_vision", "AyaVision"), + ("bamba", "Bamba"), + ("bark", "Bark"), + ("bart", "BART"), + ("barthez", "BARThez"), + ("bartpho", "BARTpho"), + ("beit", "BEiT"), + ("bert", "BERT"), + ("bert-generation", "Bert Generation"), + ("bert-japanese", "BertJapanese"), + ("bertweet", "BERTweet"), + ("big_bird", "BigBird"), + ("bigbird_pegasus", "BigBird-Pegasus"), + ("biogpt", "BioGpt"), + ("bit", "BiT"), + ("bitnet", "BitNet"), + ("blenderbot", "Blenderbot"), + ("blenderbot-small", "BlenderbotSmall"), + ("blip", "BLIP"), + ("blip-2", "BLIP-2"), + ("blip_2_qformer", "BLIP-2 QFormer"), + ("bloom", "BLOOM"), + ("blt", "Blt"), + ("bort", "BORT"), + ("bridgetower", "BridgeTower"), + ("bros", "BROS"), + ("byt5", "ByT5"), + ("camembert", "CamemBERT"), + ("canine", "CANINE"), + ("chameleon", "Chameleon"), + ("chinese_clip", "Chinese-CLIP"), + ("chinese_clip_vision_model", "ChineseCLIPVisionModel"), + ("clap", "CLAP"), + ("clip", "CLIP"), + ("clip_text_model", "CLIPTextModel"), + ("clip_vision_model", "CLIPVisionModel"), + ("clipseg", "CLIPSeg"), + ("clvp", "CLVP"), + ("code_llama", "CodeLlama"), + ("codegen", "CodeGen"), + ("cohere", "Cohere"), + ("cohere2", "Cohere2"), + ("cohere2_vision", "Cohere2Vision"), + ("colpali", "ColPali"), + ("colqwen2", "ColQwen2"), + ("conditional_detr", "Conditional DETR"), + ("convbert", "ConvBERT"), + ("convnext", "ConvNeXT"), + ("convnextv2", "ConvNeXTV2"), + ("cpm", "CPM"), + ("cpmant", "CPM-Ant"), + ("csm", "CSM"), + ("ctrl", "CTRL"), + ("cvt", "CvT"), + ("d_fine", "D-FINE"), + ("dab-detr", "DAB-DETR"), + ("dac", "DAC"), + ("data2vec-audio", "Data2VecAudio"), + ("data2vec-text", "Data2VecText"), + ("data2vec-vision", "Data2VecVision"), + 
("dbrx", "DBRX"), + ("deberta", "DeBERTa"), + ("deberta-v2", "DeBERTa-v2"), + ("decision_transformer", "Decision Transformer"), + ("deepseek_v2", "DeepSeek-V2"), + ("deepseek_v3", "DeepSeek-V3"), + ("deepseek_vl", "DeepseekVL"), + ("deepseek_vl_hybrid", "DeepseekVLHybrid"), + ("deformable_detr", "Deformable DETR"), + ("deit", "DeiT"), + ("deplot", "DePlot"), + ("depth_anything", "Depth Anything"), + ("depth_anything_v2", "Depth Anything V2"), + ("depth_pro", "DepthPro"), + ("deta", "DETA"), + ("detr", "DETR"), + ("dia", "Dia"), + ("dialogpt", "DialoGPT"), + ("diffllama", "DiffLlama"), + ("dinat", "DiNAT"), + ("dinov2", "DINOv2"), + ("dinov2_with_registers", "DINOv2 with Registers"), + ("dinov3_convnext", "DINOv3 ConvNext"), + ("dinov3_vit", "DINOv3 ViT"), + ("distilbert", "DistilBERT"), + ("dit", "DiT"), + ("doge", "Doge"), + ("donut-swin", "DonutSwin"), + ("dots1", "dots1"), + ("dpr", "DPR"), + ("dpt", "DPT"), + ("edgetam", "EdgeTAM"), + ("edgetam_video", "EdgeTamVideo"), + ("edgetam_vision_model", "EdgeTamVisionModel"), + ("efficientformer", "EfficientFormer"), + ("efficientloftr", "EfficientLoFTR"), + ("efficientnet", "EfficientNet"), + ("electra", "ELECTRA"), + ("emu3", "Emu3"), + ("encodec", "EnCodec"), + ("encoder-decoder", "Encoder decoder"), + ("eomt", "EoMT"), + ("ernie", "ERNIE"), + ("ernie4_5", "Ernie4_5"), + ("ernie4_5_moe", "Ernie4_5_MoE"), + ("ernie_m", "ErnieM"), + ("esm", "ESM"), + ("evolla", "Evolla"), + ("exaone4", "EXAONE-4.0"), + ("falcon", "Falcon"), + ("falcon3", "Falcon3"), + ("falcon_h1", "FalconH1"), + ("falcon_mamba", "FalconMamba"), + ("fastspeech2_conformer", "FastSpeech2Conformer"), + ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), + ("flan-t5", "FLAN-T5"), + ("flan-ul2", "FLAN-UL2"), + ("flaubert", "FlauBERT"), + ("flava", "FLAVA"), + ("flex_olmo", "FlexOlmo"), + ("florence2", "Florence2"), + ("fnet", "FNet"), + ("focalnet", "FocalNet"), + ("fsmt", "FairSeq Machine-Translation"), + ("funnel", "Funnel Transformer"), + ("fuyu", "Fuyu"), + ("gemma", "Gemma"), + ("gemma2", "Gemma2"), + ("gemma3", "Gemma3ForConditionalGeneration"), + ("gemma3_text", "Gemma3ForCausalLM"), + ("gemma3n", "Gemma3nForConditionalGeneration"), + ("gemma3n_audio", "Gemma3nAudioEncoder"), + ("gemma3n_text", "Gemma3nForCausalLM"), + ("gemma3n_vision", "TimmWrapperModel"), + ("git", "GIT"), + ("glm", "GLM"), + ("glm4", "GLM4"), + ("glm4_moe", "Glm4MoE"), + ("glm4v", "GLM4V"), + ("glm4v_moe", "GLM4VMOE"), + ("glm4v_moe_text", "GLM4VMOE"), + ("glm4v_text", "GLM4V"), + ("glpn", "GLPN"), + ("got_ocr2", "GOT-OCR2"), + ("gpt-sw3", "GPT-Sw3"), + ("gpt2", "OpenAI GPT-2"), + ("gpt_bigcode", "GPTBigCode"), + ("gpt_neo", "GPT Neo"), + ("gpt_neox", "GPT NeoX"), + ("gpt_neox_japanese", "GPT NeoX Japanese"), + ("gpt_oss", "GptOss"), + ("gptj", "GPT-J"), + ("gptsan-japanese", "GPTSAN-japanese"), + ("granite", "Granite"), + ("granite_speech", "GraniteSpeech"), + ("granitemoe", "GraniteMoeMoe"), + ("granitemoehybrid", "GraniteMoeHybrid"), + ("granitemoeshared", "GraniteMoeSharedMoe"), + ("granitevision", "LLaVA-NeXT"), + ("graphormer", "Graphormer"), + ("grounding-dino", "Grounding DINO"), + ("groupvit", "GroupViT"), + ("helium", "Helium"), + ("herbert", "HerBERT"), + ("hgnet_v2", "HGNet-V2"), + ("hiera", "Hiera"), + ("hubert", "Hubert"), + ("hunyuan_v1_dense", "HunYuanDenseV1"), + ("hunyuan_v1_moe", "HunYuanMoeV1"), + ("ibert", "I-BERT"), + ("idefics", "IDEFICS"), + ("idefics2", "Idefics2"), + ("idefics3", "Idefics3"), + ("idefics3_vision", "Idefics3VisionTransformer"), + ("ijepa", 
"I-JEPA"), + ("imagegpt", "ImageGPT"), + ("informer", "Informer"), + ("instructblip", "InstructBLIP"), + ("instructblipvideo", "InstructBlipVideo"), + ("internvl", "InternVL"), + ("internvl_vision", "InternVLVision"), + ("jamba", "Jamba"), + ("janus", "Janus"), + ("jetmoe", "JetMoe"), + ("jukebox", "Jukebox"), + ("kosmos-2", "KOSMOS-2"), + ("kosmos-2.5", "KOSMOS-2.5"), + ("kyutai_speech_to_text", "KyutaiSpeechToText"), + ("layoutlm", "LayoutLM"), + ("layoutlmv2", "LayoutLMv2"), + ("layoutlmv3", "LayoutLMv3"), + ("layoutxlm", "LayoutXLM"), + ("led", "LED"), + ("levit", "LeViT"), + ("lfm2", "Lfm2"), + ("lfm2_vl", "Lfm2Vl"), + ("lightglue", "LightGlue"), + ("lilt", "LiLT"), + ("llama", "LLaMA"), + ("llama2", "Llama2"), + ("llama3", "Llama3"), + ("llama4", "Llama4"), + ("llama4_text", "Llama4ForCausalLM"), + ("llava", "LLaVa"), + ("llava_next", "LLaVA-NeXT"), + ("llava_next_video", "LLaVa-NeXT-Video"), + ("llava_onevision", "LLaVA-Onevision"), + ("longcat_flash", "LongCatFlash"), + ("longformer", "Longformer"), + ("longt5", "LongT5"), + ("luke", "LUKE"), + ("lxmert", "LXMERT"), + ("m2m_100", "M2M100"), + ("madlad-400", "MADLAD-400"), + ("mamba", "Mamba"), + ("mamba2", "mamba2"), + ("marian", "Marian"), + ("markuplm", "MarkupLM"), + ("mask2former", "Mask2Former"), + ("maskformer", "MaskFormer"), + ("maskformer-swin", "MaskFormerSwin"), + ("matcha", "MatCha"), + ("mbart", "mBART"), + ("mbart50", "mBART-50"), + ("mctct", "M-CTC-T"), + ("mega", "MEGA"), + ("megatron-bert", "Megatron-BERT"), + ("megatron_gpt2", "Megatron-GPT2"), + ("metaclip_2", "MetaCLIP 2"), + ("mgp-str", "MGP-STR"), + ("mimi", "Mimi"), + ("minimax", "MiniMax"), + ("ministral", "Ministral"), + ("mistral", "Mistral"), + ("mistral3", "Mistral3"), + ("mixtral", "Mixtral"), + ("mlcd", "MLCD"), + ("mllama", "Mllama"), + ("mluke", "mLUKE"), + ("mm-grounding-dino", "MM Grounding DINO"), + ("mms", "MMS"), + ("mobilebert", "MobileBERT"), + ("mobilenet_v1", "MobileNetV1"), + ("mobilenet_v2", "MobileNetV2"), + ("mobilevit", "MobileViT"), + ("mobilevitv2", "MobileViTV2"), + ("modernbert", "ModernBERT"), + ("modernbert-decoder", "ModernBertDecoder"), + ("moonshine", "Moonshine"), + ("moshi", "Moshi"), + ("mpnet", "MPNet"), + ("mpt", "MPT"), + ("mra", "MRA"), + ("mt5", "MT5"), + ("musicgen", "MusicGen"), + ("musicgen_melody", "MusicGen Melody"), + ("mvp", "MVP"), + ("myt5", "myt5"), + ("nat", "NAT"), + ("nemotron", "Nemotron"), + ("nezha", "Nezha"), + ("nllb", "NLLB"), + ("nllb-moe", "NLLB-MOE"), + ("nougat", "Nougat"), + ("nystromformer", "Nyströmformer"), + ("olmo", "OLMo"), + ("olmo2", "OLMo2"), + ("olmo3", "Olmo3"), + ("olmoe", "OLMoE"), + ("omdet-turbo", "OmDet-Turbo"), + ("oneformer", "OneFormer"), + ("open-llama", "OpenLlama"), + ("openai-gpt", "OpenAI GPT"), + ("opt", "OPT"), + ("ovis2", "Ovis2"), + ("owlv2", "OWLv2"), + ("owlvit", "OWL-ViT"), + ("paligemma", "PaliGemma"), + ("parakeet", "Parakeet"), + ("parakeet_ctc", "Parakeet"), + ("parakeet_encoder", "ParakeetEncoder"), + ("patchtsmixer", "PatchTSMixer"), + ("patchtst", "PatchTST"), + ("pegasus", "Pegasus"), + ("pegasus_x", "PEGASUS-X"), + ("perceiver", "Perceiver"), + ("perception_encoder", "PerceptionEncoder"), + ("perception_lm", "PerceptionLM"), + ("persimmon", "Persimmon"), + ("phi", "Phi"), + ("phi3", "Phi3"), + ("phi4_multimodal", "Phi4Multimodal"), + ("phimoe", "Phimoe"), + ("phobert", "PhoBERT"), + ("pix2struct", "Pix2Struct"), + ("pixtral", "Pixtral"), + ("plbart", "PLBart"), + ("poolformer", "PoolFormer"), + ("pop2piano", "Pop2Piano"), + ("prompt_depth_anything", 
"PromptDepthAnything"), + ("prophetnet", "ProphetNet"), + ("pvt", "PVT"), + ("pvt_v2", "PVTv2"), + ("qdqbert", "QDQBert"), + ("qwen2", "Qwen2"), + ("qwen2_5_omni", "Qwen2_5Omni"), + ("qwen2_5_vl", "Qwen2_5_VL"), + ("qwen2_5_vl_text", "Qwen2_5_VL"), + ("qwen2_audio", "Qwen2Audio"), + ("qwen2_audio_encoder", "Qwen2AudioEncoder"), + ("qwen2_moe", "Qwen2MoE"), + ("qwen2_vl", "Qwen2VL"), + ("qwen2_vl_text", "Qwen2VL"), + ("qwen3", "Qwen3"), + ("qwen3_moe", "Qwen3MoE"), + ("qwen3_next", "Qwen3Next"), + ("qwen3_omni_moe", "Qwen3OmniMoE"), + ("qwen3_vl", "Qwen3VL"), + ("qwen3_vl_moe", "Qwen3VLMoe"), + ("qwen3_vl_moe_text", "Qwen3VLMoe"), + ("qwen3_vl_text", "Qwen3VL"), + ("rag", "RAG"), + ("realm", "REALM"), + ("recurrent_gemma", "RecurrentGemma"), + ("reformer", "Reformer"), + ("regnet", "RegNet"), + ("rembert", "RemBERT"), + ("resnet", "ResNet"), + ("retribert", "RetriBERT"), + ("roberta", "RoBERTa"), + ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"), + ("roc_bert", "RoCBert"), + ("roformer", "RoFormer"), + ("rt_detr", "RT-DETR"), + ("rt_detr_resnet", "RT-DETR-ResNet"), + ("rt_detr_v2", "RT-DETRv2"), + ("rwkv", "RWKV"), + ("sam", "SAM"), + ("sam2", "SAM2"), + ("sam2_hiera_det_model", "Sam2HieraDetModel"), + ("sam2_video", "Sam2VideoModel"), + ("sam2_vision_model", "Sam2VisionModel"), + ("sam_hq", "SAM-HQ"), + ("sam_hq_vision_model", "SamHQVisionModel"), + ("sam_vision_model", "SamVisionModel"), + ("seamless_m4t", "SeamlessM4T"), + ("seamless_m4t_v2", "SeamlessM4Tv2"), + ("seed_oss", "SeedOss"), + ("segformer", "SegFormer"), + ("seggpt", "SegGPT"), + ("sew", "SEW"), + ("sew-d", "SEW-D"), + ("shieldgemma2", "Shieldgemma2"), + ("siglip", "SigLIP"), + ("siglip2", "SigLIP2"), + ("siglip2_vision_model", "Siglip2VisionModel"), + ("siglip_vision_model", "SiglipVisionModel"), + ("smollm3", "SmolLM3"), + ("smolvlm", "SmolVLM"), + ("smolvlm_vision", "SmolVLMVisionTransformer"), + ("speech-encoder-decoder", "Speech Encoder decoder"), + ("speech_to_text", "Speech2Text"), + ("speech_to_text_2", "Speech2Text2"), + ("speecht5", "SpeechT5"), + ("splinter", "Splinter"), + ("squeezebert", "SqueezeBERT"), + ("stablelm", "StableLm"), + ("starcoder2", "Starcoder2"), + ("superglue", "SuperGlue"), + ("superpoint", "SuperPoint"), + ("swiftformer", "SwiftFormer"), + ("swin", "Swin Transformer"), + ("swin2sr", "Swin2SR"), + ("swinv2", "Swin Transformer V2"), + ("switch_transformers", "SwitchTransformers"), + ("t5", "T5"), + ("t5gemma", "T5Gemma"), + ("t5v1.1", "T5v1.1"), + ("table-transformer", "Table Transformer"), + ("tapas", "TAPAS"), + ("tapex", "TAPEX"), + ("textnet", "TextNet"), + ("time_series_transformer", "Time Series Transformer"), + ("timesfm", "TimesFm"), + ("timesformer", "TimeSformer"), + ("timm_backbone", "TimmBackbone"), + ("timm_wrapper", "TimmWrapperModel"), + ("trajectory_transformer", "Trajectory Transformer"), + ("transfo-xl", "Transformer-XL"), + ("trocr", "TrOCR"), + ("tvlt", "TVLT"), + ("tvp", "TVP"), + ("udop", "UDOP"), + ("ul2", "UL2"), + ("umt5", "UMT5"), + ("unispeech", "UniSpeech"), + ("unispeech-sat", "UniSpeechSat"), + ("univnet", "UnivNet"), + ("upernet", "UPerNet"), + ("van", "VAN"), + ("vaultgemma", "VaultGemma"), + ("video_llava", "VideoLlava"), + ("videomae", "VideoMAE"), + ("vilt", "ViLT"), + ("vipllava", "VipLlava"), + ("vision-encoder-decoder", "Vision Encoder decoder"), + ("vision-text-dual-encoder", "VisionTextDualEncoder"), + ("visual_bert", "VisualBERT"), + ("vit", "ViT"), + ("vit_hybrid", "ViT Hybrid"), + ("vit_mae", "ViTMAE"), + ("vit_msn", "ViTMSN"), + ("vitdet", "VitDet"), + 
("vitmatte", "ViTMatte"), + ("vitpose", "ViTPose"), + ("vitpose_backbone", "ViTPoseBackbone"), + ("vits", "VITS"), + ("vivit", "ViViT"), + ("vjepa2", "VJEPA2Model"), + ("voxtral", "Voxtral"), + ("voxtral_encoder", "Voxtral Encoder"), + ("wav2vec2", "Wav2Vec2"), + ("wav2vec2-bert", "Wav2Vec2-BERT"), + ("wav2vec2-conformer", "Wav2Vec2-Conformer"), + ("wav2vec2_phoneme", "Wav2Vec2Phoneme"), + ("wavlm", "WavLM"), + ("whisper", "Whisper"), + ("xclip", "X-CLIP"), + ("xcodec", "X-CODEC"), + ("xglm", "XGLM"), + ("xlm", "XLM"), + ("xlm-prophetnet", "XLM-ProphetNet"), + ("xlm-roberta", "XLM-RoBERTa"), + ("xlm-roberta-xl", "XLM-RoBERTa-XL"), + ("xlm-v", "XLM-V"), + ("xlnet", "XLNet"), + ("xls_r", "XLS-R"), + ("xlsr_wav2vec2", "XLSR-Wav2Vec2"), + ("xlstm", "xLSTM"), + ("xmod", "X-MOD"), + ("yolos", "YOLOS"), + ("yoso", "YOSO"), + ("zamba", "Zamba"), + ("zamba2", "Zamba2"), + ("zoedepth", "ZoeDepth"), + ] +) + +# This is tied to the processing `-` -> `_` in `model_type_to_module_name`. For example, instead of putting +# `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`. +DEPRECATED_MODELS = [ + "bort", + "deta", + "efficientformer", + "ernie_m", + "gptsan_japanese", + "graphormer", + "jukebox", + "mctct", + "mega", + "mmbt", + "nat", + "nezha", + "open_llama", + "qdqbert", + "realm", + "retribert", + "speech_to_text_2", + "tapex", + "trajectory_transformer", + "transfo_xl", + "tvlt", + "van", + "vit_hybrid", + "xlm_prophetnet", +] + +SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str]( + [ + ("openai-gpt", "openai"), + ("data2vec-audio", "data2vec"), + ("data2vec-text", "data2vec"), + ("data2vec-vision", "data2vec"), + ("donut-swin", "donut"), + ("kosmos-2", "kosmos2"), + ("kosmos-2.5", "kosmos2_5"), + ("maskformer-swin", "maskformer"), + ("xclip", "x_clip"), + ("clip_vision_model", "clip"), + ("qwen2_audio_encoder", "qwen2_audio"), + ("voxtral_encoder", "voxtral"), + ("clip_text_model", "clip"), + ("aria_text", "aria"), + ("gemma3_text", "gemma3"), + ("gemma3n_audio", "gemma3n"), + ("gemma3n_text", "gemma3n"), + ("gemma3n_vision", "gemma3n"), + ("glm4v_text", "glm4v"), + ("glm4v_moe_text", "glm4v_moe"), + ("idefics3_vision", "idefics3"), + ("siglip_vision_model", "siglip"), + ("siglip2_vision_model", "siglip2"), + ("aimv2_vision_model", "aimv2"), + ("smolvlm_vision", "smolvlm"), + ("chinese_clip_vision_model", "chinese_clip"), + ("rt_detr_resnet", "rt_detr"), + ("granitevision", "llava_next"), + ("internvl_vision", "internvl"), + ("qwen2_5_vl_text", "qwen2_5_vl"), + ("qwen2_vl_text", "qwen2_vl"), + ("qwen3_vl_text", "qwen3_vl"), + ("qwen3_vl_moe_text", "qwen3_vl_moe"), + ("sam_vision_model", "sam"), + ("sam2_vision_model", "sam2"), + ("edgetam_vision_model", "edgetam"), + ("sam2_hiera_det_model", "sam2"), + ("sam_hq_vision_model", "sam_hq"), + ("llama4_text", "llama4"), + ("blip_2_qformer", "blip_2"), + ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"), + ("perception_encoder", "perception_lm"), + ("parakeet_encoder", "parakeet"), + ("parakeet_ctc", "parakeet"), + ] +) + + +def model_type_to_module_name(key) -> str: + """Converts a config key to the corresponding module.""" + # Special treatment + if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME: + key = SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key] + + if key in DEPRECATED_MODELS: + key = f"deprecated.{key}" + return key + + key = key.replace("-", "_") + if key in DEPRECATED_MODELS: + key = f"deprecated.{key}" + + return key + + +def config_class_to_model_type(config) -> Union[str, None]: + """Converts a config class 
name to the corresponding model type""" + for key, cls in CONFIG_MAPPING_NAMES.items(): + if cls == config: + return key + # if key not found check in extra content + for key, cls in CONFIG_MAPPING._extra_content.items(): + if cls.__name__ == config: + return key + return None + + +class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]): + """ + A dictionary that lazily load its values when they are requested. + """ + + def __init__(self, mapping) -> None: + self._mapping = mapping + self._extra_content = {} + self._modules = {} + + def __getitem__(self, key: str) -> type[PretrainedConfig]: + if key in self._extra_content: + return self._extra_content[key] + if key not in self._mapping: + raise KeyError(key) + value = self._mapping[key] + module_name = model_type_to_module_name(key) + if module_name not in self._modules: + self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") + if hasattr(self._modules[module_name], value): + return getattr(self._modules[module_name], value) + + # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the + # object at the top level. + transformers_module = importlib.import_module("transformers") + return getattr(transformers_module, value) + + def keys(self) -> list[str]: + return list(self._mapping.keys()) + list(self._extra_content.keys()) + + def values(self) -> list[type[PretrainedConfig]]: + return [self[k] for k in self._mapping] + list(self._extra_content.values()) + + def items(self) -> list[tuple[str, type[PretrainedConfig]]]: + return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items()) + + def __iter__(self) -> Iterator[str]: + return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) + + def __contains__(self, item: object) -> bool: + return item in self._mapping or item in self._extra_content + + def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None: + """ + Register a new configuration in this mapping. + """ + if key in self._mapping and not exist_ok: + raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.") + self._extra_content[key] = value + + +CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES) + + +class _LazyLoadAllMappings(OrderedDict[str, str]): + """ + A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values, + etc.) + + Args: + mapping: The mapping to load. 
+ """ + + def __init__(self, mapping): + self._mapping = mapping + self._initialized = False + self._data = {} + + def _initialize(self): + if self._initialized: + return + + for model_type, map_name in self._mapping.items(): + module_name = model_type_to_module_name(model_type) + module = importlib.import_module(f".{module_name}", "transformers.models") + mapping = getattr(module, map_name) + self._data.update(mapping) + + self._initialized = True + + def __getitem__(self, key): + self._initialize() + return self._data[key] + + def keys(self) -> KeysView[str]: + self._initialize() + return self._data.keys() + + def values(self) -> ValuesView[str]: + self._initialize() + return self._data.values() + + def items(self) -> KeysView[str]: + self._initialize() + return self._data.keys() + + def __iter__(self) -> Iterator[str]: + self._initialize() + return iter(self._data) + + def __contains__(self, item: object) -> bool: + self._initialize() + return item in self._data + + +def _get_class_name(model_class: Union[str, list[str]]): + if isinstance(model_class, (list, tuple)): + return " or ".join([f"[`{c}`]" for c in model_class if c is not None]) + return f"[`{model_class}`]" + + +def _list_model_options(indent, config_to_class=None, use_model_types=True): + if config_to_class is None and not use_model_types: + raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.") + if use_model_types: + if config_to_class is None: + model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()} + else: + model_type_to_name = { + model_type: _get_class_name(model_class) + for model_type, model_class in config_to_class.items() + if model_type in MODEL_NAMES_MAPPING + } + lines = [ + f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)" + for model_type in sorted(model_type_to_name.keys()) + ] + else: + config_to_name = { + CONFIG_MAPPING_NAMES[config]: _get_class_name(clas) + for config, clas in config_to_class.items() + if config in CONFIG_MAPPING_NAMES + } + config_to_model_name = { + config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() + } + lines = [ + f"{indent}- [`{config_name}`] configuration class:" + f" {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" + for config_name in sorted(config_to_name.keys()) + ] + return "\n".join(lines) + + +def replace_list_option_in_docstrings( + config_to_class=None, use_model_types: bool = True +) -> Callable[[_CallableT], _CallableT]: + def docstring_decorator(fn): + docstrings = fn.__doc__ + if docstrings is None: + # Example: -OO + return fn + lines = docstrings.split("\n") + i = 0 + while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None: + i += 1 + if i < len(lines): + indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0] + if use_model_types: + indent = f"{indent} " + lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types) + docstrings = "\n".join(lines) + else: + raise ValueError( + f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current" + f" docstring is:\n{docstrings}" + ) + fn.__doc__ = docstrings + return fn + + return docstring_decorator + + +class AutoConfig: + r""" + This is a generic configuration class that will be instantiated as one of the configuration classes of the library + when created with the 
[`~AutoConfig.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self) -> None: + raise OSError( + "AutoConfig is designed to be instantiated " + "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig: + if model_type in CONFIG_MAPPING: + config_class = CONFIG_MAPPING[model_type] + return config_class(*args, **kwargs) + raise ValueError( + f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}" + ) + + @classmethod + @replace_list_option_in_docstrings() + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs): + r""" + Instantiate one of the configuration classes of the library from a pretrained model configuration. + + The configuration class to instantiate is selected based on the `model_type` property of the config object that + is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - A path to a *directory* containing a configuration file saved using the + [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method, + e.g., `./my_model_directory/`. + - A path or url to a saved configuration JSON *file*, e.g., + `./my_model_directory/configuration.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download the model weights and configuration files and override the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final configuration object. + + If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a + dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the + part of `kwargs` which has not been used to update `config` and is otherwise ignored. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. 
+ kwargs(additional keyword arguments, *optional*): + The values in kwargs of any keys which are configuration attributes will be used to override the loaded + values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled + by the `return_unused_kwargs` keyword parameter. + + Examples: + + ```python + >>> from transformers import AutoConfig + + >>> # Download configuration from huggingface.co and cache. + >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased") + + >>> # Download configuration from huggingface.co (user-uploaded) and cache. + >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased") + + >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*). + >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/") + + >>> # Load a specific configuration file. + >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json") + + >>> # Change some config attributes when loading a pretrained config. + >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False) + >>> config.output_attentions + True + + >>> config, unused_kwargs = AutoConfig.from_pretrained( + ... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True + ... ) + >>> config.output_attentions + True + + >>> unused_kwargs + {'foo': False} + ``` + """ + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + kwargs["_from_auto"] = True + kwargs["name_or_path"] = pretrained_model_name_or_path + trust_remote_code = kwargs.pop("trust_remote_code", None) + code_revision = kwargs.pop("code_revision", None) + + config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) + has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"] + has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING + if has_remote_code: + class_ref = config_dict["auto_map"]["AutoConfig"] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) + + if has_remote_code and trust_remote_code: + config_class = get_class_from_dynamic_module( + class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs + ) + config_class.register_for_auto_class() + return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif "model_type" in config_dict: + # Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral + if config_dict["model_type"] == "mistral" and "layer_types" in config_dict: + logger.info( + "Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. 
" + ) + config_dict["model_type"] = "ministral" + + try: + config_class = CONFIG_MAPPING[config_dict["model_type"]] + except KeyError: + raise ValueError( + f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` " + "but Transformers does not recognize this architecture. This could be because of an " + "issue with the checkpoint, or because your version of Transformers is out of date.\n\n" + "You can update Transformers with the command `pip install --upgrade transformers`. If this " + "does not work, and the checkpoint is very new, then there may not be a release version " + "that supports this model yet. In this case, you can get the most up-to-date code by installing " + "Transformers from source with the command " + "`pip install git+https://github.com/huggingface/transformers.git`" + ) + return config_class.from_dict(config_dict, **unused_kwargs) + else: + # Fallback: use pattern matching on the string. + # We go from longer names to shorter names to catch roberta before bert (for instance) + for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True): + if pattern in str(pretrained_model_name_or_path): + return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs) + + raise ValueError( + f"Unrecognized model in {pretrained_model_name_or_path}. " + f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings " + f"in its name: {', '.join(CONFIG_MAPPING.keys())}" + ) + + @staticmethod + def register(model_type, config, exist_ok=False) -> None: + """ + Register a new configuration for this class. + + Args: + model_type (`str`): The model type like "bert" or "gpt". + config ([`PretrainedConfig`]): The config to register. + """ + if issubclass(config, PretrainedConfig) and config.model_type != model_type: + raise ValueError( + "The config you are passing has a `model_type` attribute that is not consistent with the model type " + f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they " + "match!" + ) + CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok) + + +__all__ = ["CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4c4f554d9dcbaceadff65ff84d5fbe818fa96a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py @@ -0,0 +1,422 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""AutoFeatureExtractor class.""" + +import importlib +import json +import os +import warnings +from collections import OrderedDict +from typing import Optional, Union + +# Build the list of all feature extractors +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...feature_extraction_utils import FeatureExtractionMixin +from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + model_type_to_module_name, + replace_list_option_in_docstrings, +) + + +logger = logging.get_logger(__name__) + +FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( + [ + ("audio-spectrogram-transformer", "ASTFeatureExtractor"), + ("beit", "BeitFeatureExtractor"), + ("chinese_clip", "ChineseCLIPFeatureExtractor"), + ("clap", "ClapFeatureExtractor"), + ("clip", "CLIPFeatureExtractor"), + ("clipseg", "ViTFeatureExtractor"), + ("clvp", "ClvpFeatureExtractor"), + ("conditional_detr", "ConditionalDetrFeatureExtractor"), + ("convnext", "ConvNextFeatureExtractor"), + ("cvt", "ConvNextFeatureExtractor"), + ("dac", "DacFeatureExtractor"), + ("data2vec-audio", "Wav2Vec2FeatureExtractor"), + ("data2vec-vision", "BeitFeatureExtractor"), + ("deformable_detr", "DeformableDetrFeatureExtractor"), + ("deit", "DeiTFeatureExtractor"), + ("detr", "DetrFeatureExtractor"), + ("dia", "DiaFeatureExtractor"), + ("dinat", "ViTFeatureExtractor"), + ("donut-swin", "DonutFeatureExtractor"), + ("dpt", "DPTFeatureExtractor"), + ("encodec", "EncodecFeatureExtractor"), + ("flava", "FlavaFeatureExtractor"), + ("gemma3n", "Gemma3nAudioFeatureExtractor"), + ("glpn", "GLPNFeatureExtractor"), + ("granite_speech", "GraniteSpeechFeatureExtractor"), + ("groupvit", "CLIPFeatureExtractor"), + ("hubert", "Wav2Vec2FeatureExtractor"), + ("imagegpt", "ImageGPTFeatureExtractor"), + ("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"), + ("layoutlmv2", "LayoutLMv2FeatureExtractor"), + ("layoutlmv3", "LayoutLMv3FeatureExtractor"), + ("levit", "LevitFeatureExtractor"), + ("maskformer", "MaskFormerFeatureExtractor"), + ("mctct", "MCTCTFeatureExtractor"), + ("mimi", "EncodecFeatureExtractor"), + ("mobilenet_v1", "MobileNetV1FeatureExtractor"), + ("mobilenet_v2", "MobileNetV2FeatureExtractor"), + ("mobilevit", "MobileViTFeatureExtractor"), + ("moonshine", "Wav2Vec2FeatureExtractor"), + ("moshi", "EncodecFeatureExtractor"), + ("nat", "ViTFeatureExtractor"), + ("owlvit", "OwlViTFeatureExtractor"), + ("parakeet_ctc", "ParakeetFeatureExtractor"), + ("parakeet_encoder", "ParakeetFeatureExtractor"), + ("perceiver", "PerceiverFeatureExtractor"), + ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), + ("poolformer", "PoolFormerFeatureExtractor"), + ("pop2piano", "Pop2PianoFeatureExtractor"), + ("regnet", "ConvNextFeatureExtractor"), + ("resnet", "ConvNextFeatureExtractor"), + ("seamless_m4t", "SeamlessM4TFeatureExtractor"), + ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"), + ("segformer", "SegformerFeatureExtractor"), + ("sew", "Wav2Vec2FeatureExtractor"), + ("sew-d", "Wav2Vec2FeatureExtractor"), + ("speech_to_text", "Speech2TextFeatureExtractor"), + ("speecht5", "SpeechT5FeatureExtractor"), + ("swiftformer", "ViTFeatureExtractor"), + ("swin", "ViTFeatureExtractor"), + ("swinv2", "ViTFeatureExtractor"), + ("table-transformer", "DetrFeatureExtractor"), + ("timesformer", "VideoMAEFeatureExtractor"), + ("tvlt", 
"TvltFeatureExtractor"), + ("unispeech", "Wav2Vec2FeatureExtractor"), + ("unispeech-sat", "Wav2Vec2FeatureExtractor"), + ("univnet", "UnivNetFeatureExtractor"), + ("van", "ConvNextFeatureExtractor"), + ("videomae", "VideoMAEFeatureExtractor"), + ("vilt", "ViltFeatureExtractor"), + ("vit", "ViTFeatureExtractor"), + ("vit_mae", "ViTFeatureExtractor"), + ("vit_msn", "ViTFeatureExtractor"), + ("wav2vec2", "Wav2Vec2FeatureExtractor"), + ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"), + ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"), + ("wavlm", "Wav2Vec2FeatureExtractor"), + ("whisper", "WhisperFeatureExtractor"), + ("xclip", "CLIPFeatureExtractor"), + ("xcodec", "DacFeatureExtractor"), + ("yolos", "YolosFeatureExtractor"), + ] +) + +FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES) + + +def feature_extractor_class_from_name(class_name: str): + for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items(): + if class_name in extractors: + module_name = model_type_to_module_name(module_name) + + module = importlib.import_module(f".{module_name}", "transformers.models") + try: + return getattr(module, class_name) + except AttributeError: + continue + + for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values(): + if getattr(extractor, "__name__", None) == class_name: + return extractor + + # We did not fine the class, but maybe it's because a dep is missing. In that case, the class will be in the main + # init and we return the proper dummy to get an appropriate error message. + main_module = importlib.import_module("transformers") + if hasattr(main_module, class_name): + return getattr(main_module, class_name) + + return None + + +def get_feature_extractor_config( + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Loads the tokenizer configuration from a pretrained model tokenizer configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). 
+ revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `Dict`: The configuration of the tokenizer. + + Examples: + + ```python + # Download configuration from huggingface.co and cache. + tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased") + # This model does not have a tokenizer config so the result will be an empty dict. + tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base") + + # Save a pretrained tokenizer locally and you can reload its config + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased") + tokenizer.save_pretrained("tokenizer-test") + tokenizer_config = get_tokenizer_config("tokenizer-test") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + resolved_config_file = cached_file( + pretrained_model_name_or_path, + FEATURE_EXTRACTOR_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + if resolved_config_file is None: + logger.info( + "Could not locate the feature extractor configuration file, will try to use the model config instead." + ) + return {} + + with open(resolved_config_file, encoding="utf-8") as reader: + return json.load(reader) + + +class AutoFeatureExtractor: + r""" + This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the + library when created with the [`AutoFeatureExtractor.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self): + raise OSError( + "AutoFeatureExtractor is designed to be instantiated " + "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES) + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary. 
+ + The feature extractor class to instantiate is selected based on the `model_type` property of the config object + (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's + missing, by falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Params: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a feature extractor file saved using the + [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved feature extractor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model feature extractor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the feature extractor files and override the cached versions + if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final feature extractor object. If `True`, then this + functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of + `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are feature extractor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + + + Passing `token=True` is required when you want to use a private model. + + + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor + + >>> # Download feature extractor from huggingface.co and cache. 
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") + + >>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*) + >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + config = kwargs.pop("config", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs["_from_auto"] = True + + config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) + feature_extractor_class = config_dict.get("feature_extractor_type", None) + feature_extractor_auto_map = None + if "AutoFeatureExtractor" in config_dict.get("auto_map", {}): + feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"] + + # If we don't find the feature extractor class in the feature extractor config, let's try the model config. + if feature_extractor_class is None and feature_extractor_auto_map is None: + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + # It could be in `config.feature_extractor_type`` + feature_extractor_class = getattr(config, "feature_extractor_type", None) + if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map: + feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"] + + if feature_extractor_class is not None: + feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class) + + has_remote_code = feature_extractor_auto_map is not None + has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING + if has_remote_code: + if "--" in feature_extractor_auto_map: + upstream_repo = feature_extractor_auto_map.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) + + if has_remote_code and trust_remote_code: + feature_extractor_class = get_class_from_dynamic_module( + feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs + ) + _ = kwargs.pop("code_revision", None) + feature_extractor_class.register_for_auto_class() + return feature_extractor_class.from_dict(config_dict, **kwargs) + elif feature_extractor_class is not None: + return feature_extractor_class.from_dict(config_dict, **kwargs) + # Last try: we use the FEATURE_EXTRACTOR_MAPPING. + elif type(config) in FEATURE_EXTRACTOR_MAPPING: + feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)] + return feature_extractor_class.from_dict(config_dict, **kwargs) + + raise ValueError( + f"Unrecognized feature extractor in {pretrained_model_name_or_path}. 
Should have a " + f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following " + f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES)}" + ) + + @staticmethod + def register(config_class, feature_extractor_class, exist_ok=False): + """ + Register a new feature extractor for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register. + """ + FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok) + + +__all__ = ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..4b71712dfc7bbb8c3e98fe87464151a4a579f695 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py @@ -0,0 +1,688 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AutoImageProcessor class.""" + +import importlib +import json +import os +import warnings +from collections import OrderedDict +from typing import TYPE_CHECKING, Optional, Union + +# Build the list of all image processors +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...image_processing_utils import ImageProcessingMixin +from ...image_processing_utils_fast import BaseImageProcessorFast +from ...utils import ( + CONFIG_NAME, + IMAGE_PROCESSOR_NAME, + cached_file, + is_timm_config_dict, + is_timm_local_checkpoint, + is_torchvision_available, + is_vision_available, + logging, +) +from ...utils.import_utils import requires +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + model_type_to_module_name, + replace_list_option_in_docstrings, +) + + +logger = logging.get_logger(__name__) + + +FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"] + + +if TYPE_CHECKING: + # This significantly improves completion suggestion performance when + # the transformers package is used with Microsoft's Pylance language server. 
+ IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict() +else: + IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), + ("aria", ("AriaImageProcessor", None)), + ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")), + ("bit", ("BitImageProcessor", "BitImageProcessorFast")), + ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")), + ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")), + ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")), + ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")), + ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")), + ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")), + ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")), + ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")), + ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")), + ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")), + ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), + ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), + ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")), + ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")), + ("deta", ("DetaImageProcessor", None)), + ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")), + ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")), + ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")), + ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")), + ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")), + ("edgetam", (None, "Sam2ImageProcessorFast")), + ("efficientformer", ("EfficientFormerImageProcessor", None)), + ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")), + ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), + ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")), + ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")), + ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")), + ("fuyu", ("FuyuImageProcessor", None)), + ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), + ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")), + ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")), + ("glpn", ("GLPNImageProcessor", None)), + ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), + ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")), + ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("hiera", ("BitImageProcessor", "BitImageProcessorFast")), + ("idefics", ("IdeficsImageProcessor", None)), + 
("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")), + ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")), + ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")), + ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), + ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)), + ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")), + ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")), + ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")), + ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), + ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")), + ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")), + ("lightglue", ("LightGlueImageProcessor", None)), + ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")), + ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")), + ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), + ("llava_next_video", ("LlavaNextVideoImageProcessor", None)), + ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")), + ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")), + ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")), + ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")), + ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("mllama", ("MllamaImageProcessor", None)), + ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")), + ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")), + ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")), + ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")), + ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")), + ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")), + ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")), + ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")), + ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")), + ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), + ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), + ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), + ("perception_lm", (None, "PerceptionLMImageProcessorFast")), + ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")), + ("pix2struct", ("Pix2StructImageProcessor", None)), + ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), + ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")), + ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")), + ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")), + ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")), + ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), + ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), + ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), + 
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), + ("sam", ("SamImageProcessor", "SamImageProcessorFast")), + ("sam2", (None, "Sam2ImageProcessorFast")), + ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")), + ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")), + ("seggpt", ("SegGptImageProcessor", None)), + ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), + ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), + ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")), + ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")), + ("superglue", ("SuperGlueImageProcessor", None)), + ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")), + ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")), + ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")), + ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")), + ("timesformer", ("VideoMAEImageProcessor", None)), + ("timm_wrapper", ("TimmWrapperImageProcessor", None)), + ("tvlt", ("TvltImageProcessor", None)), + ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")), + ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")), + ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")), + ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), + ("videomae", ("VideoMAEImageProcessor", None)), + ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")), + ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("vit_hybrid", ("ViTHybridImageProcessor", None)), + ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")), + ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")), + ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), + ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")), + ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")), + ] + ) + +# Override to None if the packages are not available +for model_type, (slow_class, fast_class) in IMAGE_PROCESSOR_MAPPING_NAMES.items(): + if not is_vision_available(): + slow_class = None + if not is_torchvision_available(): + fast_class = None + + IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_class, fast_class) + +IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES) + + +def get_image_processor_class_from_name(class_name: str): + if class_name == "BaseImageProcessorFast": + return BaseImageProcessorFast + + for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items(): + if class_name in extractors: + module_name = model_type_to_module_name(module_name) + + module = importlib.import_module(f".{module_name}", "transformers.models") + try: + return getattr(module, class_name) + except AttributeError: + continue + + for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values(): + for extractor in extractors: + if getattr(extractor, "__name__", None) == class_name: + return extractor + + # We did not find the class, but maybe 
it's because a dep is missing. In that case, the class will be in the main
+ # init and we return the proper dummy to get an appropriate error message.
+ main_module = importlib.import_module("transformers")
+ if hasattr(main_module, class_name):
+ return getattr(main_module, class_name)
+
+ return None
+
+
+def get_image_processor_config(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: Optional[bool] = None,
+ proxies: Optional[dict[str, str]] = None,
+ token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Loads the image processor configuration from a pretrained model's image processor configuration file.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force (re-)downloading the configuration files and override the cached versions if they
+ exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the image processor configuration from local files.
+
+
+
+ Passing `token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `Dict`: The configuration of the image processor.
+
+ Examples:
+
+ ```python
+ # Download configuration from huggingface.co and cache.
+ image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
+ # This model does not have an image processor config so the result will be an empty dict.
+ image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")
+
+ # Save a pretrained image processor locally and you can reload its config
+ from transformers import AutoImageProcessor
+
+ image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ image_processor.save_pretrained("image-processor-test")
+ image_processor_config = get_image_processor_config("image-processor-test")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+ token = use_auth_token
+
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ IMAGE_PROCESSOR_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ local_files_only=local_files_only,
+ _raise_exceptions_for_gated_repo=False,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+ if resolved_config_file is None:
+ logger.info(
+ "Could not locate the image processor configuration file, will try to use the model config instead."
+ )
+ return {}
+
+ with open(resolved_config_file, encoding="utf-8") as reader:
+ return json.load(reader)
+
+
+def _warning_fast_image_processor_available(fast_class):
+ logger.warning(
+ f"Fast image processor class {fast_class} is available for this model. "
+ "Using slow image processor class. To use the fast image processor class set `use_fast=True`."
+ )
+
+
+@requires(backends=("vision",))
+class AutoImageProcessor:
+ r"""
+ This is a generic image processor class that will be instantiated as one of the image processor classes of the
+ library when created with the [`AutoImageProcessor.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise OSError(
+ "AutoImageProcessor is designed to be instantiated "
+ "using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+ r"""
+ Instantiate one of the image processor classes of the library from a pretrained model image processor configuration.
+
+ The image processor class to instantiate is selected based on the `model_type` property of the config object
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+ List options
+
+ Params:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing an image processor file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved image processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force (re-)downloading the image processor files and override the cached versions if
+ they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `hf auth login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ use_fast (`bool`, *optional*, defaults to `False`):
+ Use a fast torchvision-based image processor if it is supported for a given model.
+ If a fast image processor is not available for a given model, a normal numpy-based image processor
+ is returned instead.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final image processor object. If `True`, then this
+ function returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
+ execute code present on the Hub on your local machine.
+ image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
+ The name of the file in the model directory to use for the image processor config.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are image processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+
+
+ Passing `token=True` is required when you want to use a private model.
+
+
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor
+
+ >>> # Download image processor from huggingface.co and cache.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+
+ >>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
+ >>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
+ ```"""
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + config = kwargs.pop("config", None) + # TODO: @yoni, change in v4.48 (use_fast set to True by default) + use_fast = kwargs.pop("use_fast", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs["_from_auto"] = True + + # Resolve the image processor config filename + if "image_processor_filename" in kwargs: + image_processor_filename = kwargs.pop("image_processor_filename") + elif is_timm_local_checkpoint(pretrained_model_name_or_path): + image_processor_filename = CONFIG_NAME + else: + image_processor_filename = IMAGE_PROCESSOR_NAME + + # Load the image processor config + try: + # Main path for all transformers models and local TimmWrapper checkpoints + config_dict, _ = ImageProcessingMixin.get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs + ) + except Exception as initial_exception: + # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json` + # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information + # except the model name, the only way to check if a remote checkpoint is a timm model is to try to + # load `config.json` and if it fails with some error, we raise the initial exception. + try: + config_dict, _ = ImageProcessingMixin.get_image_processor_dict( + pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs + ) + except Exception: + raise initial_exception + + # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception, + # because only timm models have image processing in `config.json`. + if not is_timm_config_dict(config_dict): + raise initial_exception + + image_processor_type = config_dict.get("image_processor_type", None) + image_processor_auto_map = None + if "AutoImageProcessor" in config_dict.get("auto_map", {}): + image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"] + + # If we still don't have the image processor class, check if we're loading from a previous feature extractor config + # and if so, infer the image processor class from there. + if image_processor_type is None and image_processor_auto_map is None: + feature_extractor_class = config_dict.pop("feature_extractor_type", None) + if feature_extractor_class is not None: + image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor") + if "AutoFeatureExtractor" in config_dict.get("auto_map", {}): + feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"] + image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor") + + # If we don't find the image processor class in the image processor config, let's try the model config. 
+ if image_processor_type is None and image_processor_auto_map is None: + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + **kwargs, + ) + # It could be in `config.image_processor_type`` + image_processor_type = getattr(config, "image_processor_type", None) + if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map: + image_processor_auto_map = config.auto_map["AutoImageProcessor"] + + image_processor_class = None + # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default) + if image_processor_type is not None: + # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor. + if use_fast is None: + use_fast = image_processor_type.endswith("Fast") + if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available(): + use_fast = True + logger.warning_once( + f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. " + "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. " + "Note that this behavior will be extended to all models in a future release." + ) + if not use_fast: + logger.warning_once( + "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. " + "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. " + "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`." + ) + if use_fast and not image_processor_type.endswith("Fast"): + image_processor_type += "Fast" + if use_fast and not is_torchvision_available(): + # check if there is a slow image processor class to fallback to + image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4]) + if image_processor_class is None: + raise ValueError( + f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again." + ) + logger.warning_once( + "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor." + ) + use_fast = False + if use_fast: + for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values(): + if image_processor_type in image_processors: + break + else: + image_processor_type = image_processor_type[:-4] + use_fast = False + logger.warning_once( + "`use_fast` is set to `True` but the image processor class does not have a fast version. " + " Falling back to the slow version." + ) + image_processor_class = get_image_processor_class_from_name(image_processor_type) + else: + image_processor_type_slow = image_processor_type.removesuffix("Fast") + image_processor_class = get_image_processor_class_from_name(image_processor_type_slow) + if image_processor_class is None and image_processor_type.endswith("Fast"): + raise ValueError( + f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor." 
+ )
+
+ has_remote_code = image_processor_auto_map is not None
+ has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
+ if has_remote_code:
+ if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
+ # In some configs, only the slow image processor class is stored
+ image_processor_auto_map = (image_processor_auto_map, None)
+ if use_fast and image_processor_auto_map[1] is not None:
+ class_ref = image_processor_auto_map[1]
+ else:
+ class_ref = image_processor_auto_map[0]
+ if "--" in class_ref:
+ upstream_repo = class_ref.split("--")[0]
+ else:
+ upstream_repo = None
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+ )
+
+ if has_remote_code and trust_remote_code:
+ if not use_fast and image_processor_auto_map[1] is not None:
+ _warning_fast_image_processor_available(image_processor_auto_map[1])
+
+ image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+ _ = kwargs.pop("code_revision", None)
+ image_processor_class.register_for_auto_class()
+ return image_processor_class.from_dict(config_dict, **kwargs)
+ elif image_processor_class is not None:
+ return image_processor_class.from_dict(config_dict, **kwargs)
+ # Last try: we use the IMAGE_PROCESSOR_MAPPING.
+ elif type(config) in IMAGE_PROCESSOR_MAPPING:
+ image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
+
+ image_processor_class_py, image_processor_class_fast = image_processor_tuple
+
+ if not use_fast and image_processor_class_fast is not None:
+ _warning_fast_image_processor_available(image_processor_class_fast)
+
+ if image_processor_class_fast and (use_fast or image_processor_class_py is None):
+ return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ if image_processor_class_py is not None:
+ return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+ else:
+ raise ValueError(
+ "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
+ )
+ raise ValueError(
+ f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have an "
+ f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} or {CONFIG_NAME}, or one of the following "
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
+ )
+
+ @staticmethod
+ def register(
+ config_class,
+ image_processor_class=None,
+ slow_image_processor_class=None,
+ fast_image_processor_class=None,
+ exist_ok=False,
+ ):
+ """
+ Register a new image processor for this class.
+
+ Args:
+ config_class ([`PretrainedConfig`]):
+ The configuration corresponding to the model to register.
+ image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
+ """
+ if image_processor_class is not None:
+ if slow_image_processor_class is not None:
+ raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
+ warnings.warn(
+ "The image_processor_class argument is deprecated and will be removed in v4.42. 
Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
+ FutureWarning,
+ )
+ slow_image_processor_class = image_processor_class
+
+ if slow_image_processor_class is None and fast_image_processor_class is None:
+ raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
+ if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
+ raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
+ if fast_image_processor_class is not None and not issubclass(
+ fast_image_processor_class, BaseImageProcessorFast
+ ):
+ raise ValueError("The `fast_image_processor_class` should inherit from `BaseImageProcessorFast`.")
+
+ if (
+ slow_image_processor_class is not None
+ and fast_image_processor_class is not None
+ and issubclass(fast_image_processor_class, BaseImageProcessorFast)
+ and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
+ ):
+ raise ValueError(
+ "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
+ "consistent with the slow processor class you passed (fast image processor has "
+ f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}). Fix one of those "
+ "so they match!"
+ )
+
+ # Avoid resetting a set slow/fast image processor if we are passing just the other ones.
+ if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
+ existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
+ if slow_image_processor_class is None:
+ slow_image_processor_class = existing_slow
+ if fast_image_processor_class is None:
+ fast_image_processor_class = existing_fast
+
+ IMAGE_PROCESSOR_MAPPING.register(
+ config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
+ )
+
+
+__all__ = ["IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..298834bebe9303b90a325d48f4be6ab732bfe051
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py
@@ -0,0 +1,2382 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Auto Model class.""" + +import os +import warnings +from collections import OrderedDict +from typing import TYPE_CHECKING, Union + +from ...utils import logging +from .auto_factory import ( + _BaseAutoBackboneClass, + _BaseAutoModelClass, + _LazyAutoMapping, + auto_class_update, +) +from .configuration_auto import CONFIG_MAPPING_NAMES + + +if TYPE_CHECKING: + from ...generation import GenerationMixin + from ...modeling_utils import PreTrainedModel + + # class for better type annotations + class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): + pass + + +logger = logging.get_logger(__name__) + +MODEL_MAPPING_NAMES = OrderedDict( + [ + # Base model mapping + ("aimv2", "Aimv2Model"), + ("aimv2_vision_model", "Aimv2VisionModel"), + ("albert", "AlbertModel"), + ("align", "AlignModel"), + ("altclip", "AltCLIPModel"), + ("apertus", "ApertusModel"), + ("arcee", "ArceeModel"), + ("aria", "AriaModel"), + ("aria_text", "AriaTextModel"), + ("audio-spectrogram-transformer", "ASTModel"), + ("autoformer", "AutoformerModel"), + ("aya_vision", "AyaVisionModel"), + ("bamba", "BambaModel"), + ("bark", "BarkModel"), + ("bart", "BartModel"), + ("beit", "BeitModel"), + ("bert", "BertModel"), + ("bert-generation", "BertGenerationEncoder"), + ("big_bird", "BigBirdModel"), + ("bigbird_pegasus", "BigBirdPegasusModel"), + ("biogpt", "BioGptModel"), + ("bit", "BitModel"), + ("bitnet", "BitNetModel"), + ("blenderbot", "BlenderbotModel"), + ("blenderbot-small", "BlenderbotSmallModel"), + ("blip", "BlipModel"), + ("blip-2", "Blip2Model"), + ("blip_2_qformer", "Blip2QFormerModel"), + ("bloom", "BloomModel"), + ("blt", "BltModel"), + ("bridgetower", "BridgeTowerModel"), + ("bros", "BrosModel"), + ("camembert", "CamembertModel"), + ("canine", "CanineModel"), + ("chameleon", "ChameleonModel"), + ("chinese_clip", "ChineseCLIPModel"), + ("chinese_clip_vision_model", "ChineseCLIPVisionModel"), + ("clap", "ClapModel"), + ("clip", "CLIPModel"), + ("clip_text_model", "CLIPTextModel"), + ("clip_vision_model", "CLIPVisionModel"), + ("clipseg", "CLIPSegModel"), + ("clvp", "ClvpModelForConditionalGeneration"), + ("code_llama", "LlamaModel"), + ("codegen", "CodeGenModel"), + ("cohere", "CohereModel"), + ("cohere2", "Cohere2Model"), + ("cohere2_vision", "Cohere2VisionModel"), + ("conditional_detr", "ConditionalDetrModel"), + ("convbert", "ConvBertModel"), + ("convnext", "ConvNextModel"), + ("convnextv2", "ConvNextV2Model"), + ("cpmant", "CpmAntModel"), + ("csm", "CsmForConditionalGeneration"), + ("ctrl", "CTRLModel"), + ("cvt", "CvtModel"), + ("d_fine", "DFineModel"), + ("dab-detr", "DabDetrModel"), + ("dac", "DacModel"), + ("data2vec-audio", "Data2VecAudioModel"), + ("data2vec-text", "Data2VecTextModel"), + ("data2vec-vision", "Data2VecVisionModel"), + ("dbrx", "DbrxModel"), + ("deberta", "DebertaModel"), + ("deberta-v2", "DebertaV2Model"), + ("decision_transformer", "DecisionTransformerModel"), + ("deepseek_v2", "DeepseekV2Model"), + ("deepseek_v3", "DeepseekV3Model"), + ("deepseek_vl", "DeepseekVLModel"), + ("deepseek_vl_hybrid", "DeepseekVLHybridModel"), + ("deformable_detr", "DeformableDetrModel"), + ("deit", "DeiTModel"), + ("depth_pro", "DepthProModel"), + ("deta", "DetaModel"), + ("detr", "DetrModel"), + ("dia", "DiaModel"), + ("diffllama", "DiffLlamaModel"), + ("dinat", "DinatModel"), + ("dinov2", "Dinov2Model"), + ("dinov2_with_registers", "Dinov2WithRegistersModel"), + ("dinov3_convnext", "DINOv3ConvNextModel"), + ("dinov3_vit", "DINOv3ViTModel"), + ("distilbert", "DistilBertModel"), + ("doge", "DogeModel"), + 
("donut-swin", "DonutSwinModel"), + ("dots1", "Dots1Model"), + ("dpr", "DPRQuestionEncoder"), + ("dpt", "DPTModel"), + ("edgetam", "EdgeTamModel"), + ("edgetam_video", "EdgeTamVideoModel"), + ("edgetam_vision_model", "EdgeTamVisionModel"), + ("efficientformer", "EfficientFormerModel"), + ("efficientloftr", "EfficientLoFTRModel"), + ("efficientnet", "EfficientNetModel"), + ("electra", "ElectraModel"), + ("emu3", "Emu3Model"), + ("encodec", "EncodecModel"), + ("ernie", "ErnieModel"), + ("ernie4_5", "Ernie4_5Model"), + ("ernie4_5_moe", "Ernie4_5_MoeModel"), + ("ernie_m", "ErnieMModel"), + ("esm", "EsmModel"), + ("evolla", "EvollaModel"), + ("exaone4", "Exaone4Model"), + ("falcon", "FalconModel"), + ("falcon_h1", "FalconH1Model"), + ("falcon_mamba", "FalconMambaModel"), + ("fastspeech2_conformer", "FastSpeech2ConformerModel"), + ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), + ("flaubert", "FlaubertModel"), + ("flava", "FlavaModel"), + ("flex_olmo", "FlexOlmoModel"), + ("florence2", "Florence2Model"), + ("fnet", "FNetModel"), + ("focalnet", "FocalNetModel"), + ("fsmt", "FSMTModel"), + ("funnel", ("FunnelModel", "FunnelBaseModel")), + ("fuyu", "FuyuModel"), + ("gemma", "GemmaModel"), + ("gemma2", "Gemma2Model"), + ("gemma3", "Gemma3Model"), + ("gemma3_text", "Gemma3TextModel"), + ("gemma3n", "Gemma3nModel"), + ("gemma3n_audio", "Gemma3nAudioEncoder"), + ("gemma3n_text", "Gemma3nTextModel"), + ("gemma3n_vision", "TimmWrapperModel"), + ("git", "GitModel"), + ("glm", "GlmModel"), + ("glm4", "Glm4Model"), + ("glm4_moe", "Glm4MoeModel"), + ("glm4v", "Glm4vModel"), + ("glm4v_moe", "Glm4vMoeModel"), + ("glm4v_moe_text", "Glm4vMoeTextModel"), + ("glm4v_text", "Glm4vTextModel"), + ("glpn", "GLPNModel"), + ("got_ocr2", "GotOcr2Model"), + ("gpt-sw3", "GPT2Model"), + ("gpt2", "GPT2Model"), + ("gpt_bigcode", "GPTBigCodeModel"), + ("gpt_neo", "GPTNeoModel"), + ("gpt_neox", "GPTNeoXModel"), + ("gpt_neox_japanese", "GPTNeoXJapaneseModel"), + ("gpt_oss", "GptOssModel"), + ("gptj", "GPTJModel"), + ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), + ("granite", "GraniteModel"), + ("granitemoe", "GraniteMoeModel"), + ("granitemoehybrid", "GraniteMoeHybridModel"), + ("granitemoeshared", "GraniteMoeSharedModel"), + ("graphormer", "GraphormerModel"), + ("grounding-dino", "GroundingDinoModel"), + ("groupvit", "GroupViTModel"), + ("helium", "HeliumModel"), + ("hgnet_v2", "HGNetV2Backbone"), + ("hiera", "HieraModel"), + ("hubert", "HubertModel"), + ("hunyuan_v1_dense", "HunYuanDenseV1Model"), + ("hunyuan_v1_moe", "HunYuanMoEV1Model"), + ("ibert", "IBertModel"), + ("idefics", "IdeficsModel"), + ("idefics2", "Idefics2Model"), + ("idefics3", "Idefics3Model"), + ("idefics3_vision", "Idefics3VisionTransformer"), + ("ijepa", "IJepaModel"), + ("imagegpt", "ImageGPTModel"), + ("informer", "InformerModel"), + ("instructblip", "InstructBlipModel"), + ("instructblipvideo", "InstructBlipVideoModel"), + ("internvl", "InternVLModel"), + ("internvl_vision", "InternVLVisionModel"), + ("jamba", "JambaModel"), + ("janus", "JanusModel"), + ("jetmoe", "JetMoeModel"), + ("jukebox", "JukeboxModel"), + ("kosmos-2", "Kosmos2Model"), + ("kosmos-2.5", "Kosmos2_5Model"), + ("kyutai_speech_to_text", "KyutaiSpeechToTextModel"), + ("layoutlm", "LayoutLMModel"), + ("layoutlmv2", "LayoutLMv2Model"), + ("layoutlmv3", "LayoutLMv3Model"), + ("led", "LEDModel"), + ("levit", "LevitModel"), + ("lfm2", "Lfm2Model"), + ("lfm2_vl", "Lfm2VlModel"), + ("lightglue", "LightGlueForKeypointMatching"), + ("lilt", "LiltModel"), 
+ ("llama", "LlamaModel"), + ("llama4", "Llama4ForConditionalGeneration"), + ("llama4_text", "Llama4TextModel"), + ("llava", "LlavaModel"), + ("llava_next", "LlavaNextModel"), + ("llava_next_video", "LlavaNextVideoModel"), + ("llava_onevision", "LlavaOnevisionModel"), + ("longcat_flash", "LongcatFlashModel"), + ("longformer", "LongformerModel"), + ("longt5", "LongT5Model"), + ("luke", "LukeModel"), + ("lxmert", "LxmertModel"), + ("m2m_100", "M2M100Model"), + ("mamba", "MambaModel"), + ("mamba2", "Mamba2Model"), + ("marian", "MarianModel"), + ("markuplm", "MarkupLMModel"), + ("mask2former", "Mask2FormerModel"), + ("maskformer", "MaskFormerModel"), + ("maskformer-swin", "MaskFormerSwinModel"), + ("mbart", "MBartModel"), + ("mctct", "MCTCTModel"), + ("mega", "MegaModel"), + ("megatron-bert", "MegatronBertModel"), + ("metaclip_2", "MetaClip2Model"), + ("mgp-str", "MgpstrForSceneTextRecognition"), + ("mimi", "MimiModel"), + ("minimax", "MiniMaxModel"), + ("ministral", "MinistralModel"), + ("mistral", "MistralModel"), + ("mistral3", "Mistral3Model"), + ("mixtral", "MixtralModel"), + ("mlcd", "MLCDVisionModel"), + ("mllama", "MllamaModel"), + ("mm-grounding-dino", "MMGroundingDinoModel"), + ("mobilebert", "MobileBertModel"), + ("mobilenet_v1", "MobileNetV1Model"), + ("mobilenet_v2", "MobileNetV2Model"), + ("mobilevit", "MobileViTModel"), + ("mobilevitv2", "MobileViTV2Model"), + ("modernbert", "ModernBertModel"), + ("modernbert-decoder", "ModernBertDecoderModel"), + ("moonshine", "MoonshineModel"), + ("moshi", "MoshiModel"), + ("mpnet", "MPNetModel"), + ("mpt", "MptModel"), + ("mra", "MraModel"), + ("mt5", "MT5Model"), + ("musicgen", "MusicgenModel"), + ("musicgen_melody", "MusicgenMelodyModel"), + ("mvp", "MvpModel"), + ("nat", "NatModel"), + ("nemotron", "NemotronModel"), + ("nezha", "NezhaModel"), + ("nllb-moe", "NllbMoeModel"), + ("nystromformer", "NystromformerModel"), + ("olmo", "OlmoModel"), + ("olmo2", "Olmo2Model"), + ("olmo3", "Olmo3Model"), + ("olmoe", "OlmoeModel"), + ("omdet-turbo", "OmDetTurboForObjectDetection"), + ("oneformer", "OneFormerModel"), + ("open-llama", "OpenLlamaModel"), + ("openai-gpt", "OpenAIGPTModel"), + ("opt", "OPTModel"), + ("ovis2", "Ovis2Model"), + ("owlv2", "Owlv2Model"), + ("owlvit", "OwlViTModel"), + ("paligemma", "PaliGemmaModel"), + ("parakeet_ctc", "ParakeetForCTC"), + ("parakeet_encoder", "ParakeetEncoder"), + ("patchtsmixer", "PatchTSMixerModel"), + ("patchtst", "PatchTSTModel"), + ("pegasus", "PegasusModel"), + ("pegasus_x", "PegasusXModel"), + ("perceiver", "PerceiverModel"), + ("perception_encoder", "PerceptionEncoder"), + ("perception_lm", "PerceptionLMModel"), + ("persimmon", "PersimmonModel"), + ("phi", "PhiModel"), + ("phi3", "Phi3Model"), + ("phi4_multimodal", "Phi4MultimodalModel"), + ("phimoe", "PhimoeModel"), + ("pixtral", "PixtralVisionModel"), + ("plbart", "PLBartModel"), + ("poolformer", "PoolFormerModel"), + ("prophetnet", "ProphetNetModel"), + ("pvt", "PvtModel"), + ("pvt_v2", "PvtV2Model"), + ("qdqbert", "QDQBertModel"), + ("qwen2", "Qwen2Model"), + ("qwen2_5_vl", "Qwen2_5_VLModel"), + ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"), + ("qwen2_audio_encoder", "Qwen2AudioEncoder"), + ("qwen2_moe", "Qwen2MoeModel"), + ("qwen2_vl", "Qwen2VLModel"), + ("qwen2_vl_text", "Qwen2VLTextModel"), + ("qwen3", "Qwen3Model"), + ("qwen3_moe", "Qwen3MoeModel"), + ("qwen3_next", "Qwen3NextModel"), + ("qwen3_vl", "Qwen3VLModel"), + ("qwen3_vl_moe", "Qwen3VLMoeModel"), + ("qwen3_vl_moe_text", "Qwen3VLMoeTextModel"), + ("qwen3_vl_text", "Qwen3VLTextModel"), + 
("recurrent_gemma", "RecurrentGemmaModel"), + ("reformer", "ReformerModel"), + ("regnet", "RegNetModel"), + ("rembert", "RemBertModel"), + ("resnet", "ResNetModel"), + ("retribert", "RetriBertModel"), + ("roberta", "RobertaModel"), + ("roberta-prelayernorm", "RobertaPreLayerNormModel"), + ("roc_bert", "RoCBertModel"), + ("roformer", "RoFormerModel"), + ("rt_detr", "RTDetrModel"), + ("rt_detr_v2", "RTDetrV2Model"), + ("rwkv", "RwkvModel"), + ("sam", "SamModel"), + ("sam2", "Sam2Model"), + ("sam2_hiera_det_model", "Sam2HieraDetModel"), + ("sam2_video", "Sam2VideoModel"), + ("sam2_vision_model", "Sam2VisionModel"), + ("sam_hq", "SamHQModel"), + ("sam_hq_vision_model", "SamHQVisionModel"), + ("sam_vision_model", "SamVisionModel"), + ("seamless_m4t", "SeamlessM4TModel"), + ("seamless_m4t_v2", "SeamlessM4Tv2Model"), + ("seed_oss", "SeedOssModel"), + ("segformer", "SegformerModel"), + ("seggpt", "SegGptModel"), + ("sew", "SEWModel"), + ("sew-d", "SEWDModel"), + ("siglip", "SiglipModel"), + ("siglip2", "Siglip2Model"), + ("siglip2_vision_model", "Siglip2VisionModel"), + ("siglip_vision_model", "SiglipVisionModel"), + ("smollm3", "SmolLM3Model"), + ("smolvlm", "SmolVLMModel"), + ("smolvlm_vision", "SmolVLMVisionTransformer"), + ("speech_to_text", "Speech2TextModel"), + ("speecht5", "SpeechT5Model"), + ("splinter", "SplinterModel"), + ("squeezebert", "SqueezeBertModel"), + ("stablelm", "StableLmModel"), + ("starcoder2", "Starcoder2Model"), + ("swiftformer", "SwiftFormerModel"), + ("swin", "SwinModel"), + ("swin2sr", "Swin2SRModel"), + ("swinv2", "Swinv2Model"), + ("switch_transformers", "SwitchTransformersModel"), + ("t5", "T5Model"), + ("t5gemma", "T5GemmaModel"), + ("table-transformer", "TableTransformerModel"), + ("tapas", "TapasModel"), + ("textnet", "TextNetModel"), + ("time_series_transformer", "TimeSeriesTransformerModel"), + ("timesfm", "TimesFmModel"), + ("timesformer", "TimesformerModel"), + ("timm_backbone", "TimmBackbone"), + ("timm_wrapper", "TimmWrapperModel"), + ("trajectory_transformer", "TrajectoryTransformerModel"), + ("transfo-xl", "TransfoXLModel"), + ("tvlt", "TvltModel"), + ("tvp", "TvpModel"), + ("udop", "UdopModel"), + ("umt5", "UMT5Model"), + ("unispeech", "UniSpeechModel"), + ("unispeech-sat", "UniSpeechSatModel"), + ("univnet", "UnivNetModel"), + ("van", "VanModel"), + ("vaultgemma", "VaultGemmaModel"), + ("video_llava", "VideoLlavaModel"), + ("videomae", "VideoMAEModel"), + ("vilt", "ViltModel"), + ("vipllava", "VipLlavaModel"), + ("vision-text-dual-encoder", "VisionTextDualEncoderModel"), + ("visual_bert", "VisualBertModel"), + ("vit", "ViTModel"), + ("vit_hybrid", "ViTHybridModel"), + ("vit_mae", "ViTMAEModel"), + ("vit_msn", "ViTMSNModel"), + ("vitdet", "VitDetModel"), + ("vits", "VitsModel"), + ("vivit", "VivitModel"), + ("vjepa2", "VJEPA2Model"), + ("voxtral", "VoxtralForConditionalGeneration"), + ("voxtral_encoder", "VoxtralEncoder"), + ("wav2vec2", "Wav2Vec2Model"), + ("wav2vec2-bert", "Wav2Vec2BertModel"), + ("wav2vec2-conformer", "Wav2Vec2ConformerModel"), + ("wavlm", "WavLMModel"), + ("whisper", "WhisperModel"), + ("xclip", "XCLIPModel"), + ("xcodec", "XcodecModel"), + ("xglm", "XGLMModel"), + ("xlm", "XLMModel"), + ("xlm-prophetnet", "XLMProphetNetModel"), + ("xlm-roberta", "XLMRobertaModel"), + ("xlm-roberta-xl", "XLMRobertaXLModel"), + ("xlnet", "XLNetModel"), + ("xlstm", "xLSTMModel"), + ("xmod", "XmodModel"), + ("yolos", "YolosModel"), + ("yoso", "YosoModel"), + ("zamba", "ZambaModel"), + ("zamba2", "Zamba2Model"), + ] +) + 
+MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( + [ + # Model for pre-training mapping + ("albert", "AlbertForPreTraining"), + ("bart", "BartForConditionalGeneration"), + ("bert", "BertForPreTraining"), + ("big_bird", "BigBirdForPreTraining"), + ("bloom", "BloomForCausalLM"), + ("camembert", "CamembertForMaskedLM"), + ("colpali", "ColPaliForRetrieval"), + ("colqwen2", "ColQwen2ForRetrieval"), + ("ctrl", "CTRLLMHeadModel"), + ("data2vec-text", "Data2VecTextForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("distilbert", "DistilBertForMaskedLM"), + ("electra", "ElectraForPreTraining"), + ("ernie", "ErnieForPreTraining"), + ("evolla", "EvollaForProteinText2Text"), + ("exaone4", "Exaone4ForCausalLM"), + ("falcon_mamba", "FalconMambaForCausalLM"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("flava", "FlavaForPreTraining"), + ("florence2", "Florence2ForConditionalGeneration"), + ("fnet", "FNetForPreTraining"), + ("fsmt", "FSMTForConditionalGeneration"), + ("funnel", "FunnelForPreTraining"), + ("gemma3", "Gemma3ForConditionalGeneration"), + ("gpt-sw3", "GPT2LMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("gpt_bigcode", "GPTBigCodeForCausalLM"), + ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), + ("hiera", "HieraForPreTraining"), + ("ibert", "IBertForMaskedLM"), + ("idefics", "IdeficsForVisionText2Text"), + ("idefics2", "Idefics2ForConditionalGeneration"), + ("idefics3", "Idefics3ForConditionalGeneration"), + ("janus", "JanusForConditionalGeneration"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("llava", "LlavaForConditionalGeneration"), + ("llava_next", "LlavaNextForConditionalGeneration"), + ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), + ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("luke", "LukeForMaskedLM"), + ("lxmert", "LxmertForPreTraining"), + ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), + ("mega", "MegaForMaskedLM"), + ("megatron-bert", "MegatronBertForPreTraining"), + ("mistral3", "Mistral3ForConditionalGeneration"), + ("mllama", "MllamaForConditionalGeneration"), + ("mobilebert", "MobileBertForPreTraining"), + ("mpnet", "MPNetForMaskedLM"), + ("mpt", "MptForCausalLM"), + ("mra", "MraForMaskedLM"), + ("mvp", "MvpForConditionalGeneration"), + ("nezha", "NezhaForPreTraining"), + ("nllb-moe", "NllbMoeForConditionalGeneration"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("paligemma", "PaliGemmaForConditionalGeneration"), + ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), + ("retribert", "RetriBertModel"), + ("roberta", "RobertaForMaskedLM"), + ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), + ("roc_bert", "RoCBertForPreTraining"), + ("rwkv", "RwkvForCausalLM"), + ("splinter", "SplinterForPreTraining"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("switch_transformers", "SwitchTransformersForConditionalGeneration"), + ("t5", "T5ForConditionalGeneration"), + ("t5gemma", "T5GemmaForConditionalGeneration"), + ("tapas", "TapasForMaskedLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("tvlt", "TvltForPreTraining"), + ("unispeech", "UniSpeechForPreTraining"), + ("unispeech-sat", "UniSpeechSatForPreTraining"), + ("video_llava", "VideoLlavaForConditionalGeneration"), + ("videomae", "VideoMAEForPreTraining"), + ("vipllava", "VipLlavaForConditionalGeneration"), + ("visual_bert", "VisualBertForPreTraining"), + ("vit_mae", "ViTMAEForPreTraining"), + ("voxtral", "VoxtralForConditionalGeneration"), + ("wav2vec2", 
"Wav2Vec2ForPreTraining"), + ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), + ("xlnet", "XLNetLMHeadModel"), + ("xlstm", "xLSTMForCausalLM"), + ("xmod", "XmodForMaskedLM"), + ] +) + +MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( + [ + # Model with LM heads mapping + ("albert", "AlbertForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("bert", "BertForMaskedLM"), + ("big_bird", "BigBirdForMaskedLM"), + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("bloom", "BloomForCausalLM"), + ("camembert", "CamembertForMaskedLM"), + ("codegen", "CodeGenForCausalLM"), + ("convbert", "ConvBertForMaskedLM"), + ("cpmant", "CpmAntForCausalLM"), + ("ctrl", "CTRLLMHeadModel"), + ("data2vec-text", "Data2VecTextForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("dia", "DiaForConditionalGeneration"), + ("distilbert", "DistilBertForMaskedLM"), + ("electra", "ElectraForMaskedLM"), + ("encoder-decoder", "EncoderDecoderModel"), + ("ernie", "ErnieForMaskedLM"), + ("esm", "EsmForMaskedLM"), + ("exaone4", "Exaone4ForCausalLM"), + ("falcon_mamba", "FalconMambaForCausalLM"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("fnet", "FNetForMaskedLM"), + ("fsmt", "FSMTForConditionalGeneration"), + ("funnel", "FunnelForMaskedLM"), + ("git", "GitForCausalLM"), + ("gpt-sw3", "GPT2LMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("gpt_bigcode", "GPTBigCodeForCausalLM"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("gpt_neox", "GPTNeoXForCausalLM"), + ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), + ("gptj", "GPTJForCausalLM"), + ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), + ("ibert", "IBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("led", "LEDForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("longt5", "LongT5ForConditionalGeneration"), + ("luke", "LukeForMaskedLM"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), + ("marian", "MarianMTModel"), + ("mega", "MegaForMaskedLM"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("moonshine", "MoonshineForConditionalGeneration"), + ("mpnet", "MPNetForMaskedLM"), + ("mpt", "MptForCausalLM"), + ("mra", "MraForMaskedLM"), + ("mvp", "MvpForConditionalGeneration"), + ("nezha", "NezhaForMaskedLM"), + ("nllb-moe", "NllbMoeForConditionalGeneration"), + ("nystromformer", "NystromformerForMaskedLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("pegasus_x", "PegasusXForConditionalGeneration"), + ("plbart", "PLBartForConditionalGeneration"), + ("pop2piano", "Pop2PianoForConditionalGeneration"), + ("qdqbert", "QDQBertForMaskedLM"), + ("reformer", "ReformerModelWithLMHead"), + ("rembert", "RemBertForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), + ("roc_bert", "RoCBertForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("rwkv", "RwkvForCausalLM"), + ("speech_to_text", "Speech2TextForConditionalGeneration"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("switch_transformers", "SwitchTransformersForConditionalGeneration"), + ("t5", "T5ForConditionalGeneration"), + ("t5gemma", "T5GemmaForConditionalGeneration"), + ("tapas", "TapasForMaskedLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + 
("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("whisper", "WhisperForConditionalGeneration"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), + ("xlnet", "XLNetLMHeadModel"), + ("xmod", "XmodForMaskedLM"), + ("yoso", "YosoForMaskedLM"), + ] +) + +MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Causal LM mapping + ("apertus", "ApertusForCausalLM"), + ("arcee", "ArceeForCausalLM"), + ("aria_text", "AriaTextForCausalLM"), + ("bamba", "BambaForCausalLM"), + ("bart", "BartForCausalLM"), + ("bert", "BertLMHeadModel"), + ("bert-generation", "BertGenerationDecoder"), + ("big_bird", "BigBirdForCausalLM"), + ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), + ("biogpt", "BioGptForCausalLM"), + ("bitnet", "BitNetForCausalLM"), + ("blenderbot", "BlenderbotForCausalLM"), + ("blenderbot-small", "BlenderbotSmallForCausalLM"), + ("bloom", "BloomForCausalLM"), + ("blt", "BltForCausalLM"), + ("camembert", "CamembertForCausalLM"), + ("code_llama", "LlamaForCausalLM"), + ("codegen", "CodeGenForCausalLM"), + ("cohere", "CohereForCausalLM"), + ("cohere2", "Cohere2ForCausalLM"), + ("cpmant", "CpmAntForCausalLM"), + ("ctrl", "CTRLLMHeadModel"), + ("data2vec-text", "Data2VecTextForCausalLM"), + ("dbrx", "DbrxForCausalLM"), + ("deepseek_v2", "DeepseekV2ForCausalLM"), + ("deepseek_v3", "DeepseekV3ForCausalLM"), + ("diffllama", "DiffLlamaForCausalLM"), + ("doge", "DogeForCausalLM"), + ("dots1", "Dots1ForCausalLM"), + ("electra", "ElectraForCausalLM"), + ("emu3", "Emu3ForCausalLM"), + ("ernie", "ErnieForCausalLM"), + ("ernie4_5", "Ernie4_5ForCausalLM"), + ("ernie4_5_moe", "Ernie4_5_MoeForCausalLM"), + ("exaone4", "Exaone4ForCausalLM"), + ("falcon", "FalconForCausalLM"), + ("falcon_h1", "FalconH1ForCausalLM"), + ("falcon_mamba", "FalconMambaForCausalLM"), + ("flex_olmo", "FlexOlmoForCausalLM"), + ("fuyu", "FuyuForCausalLM"), + ("gemma", "GemmaForCausalLM"), + ("gemma2", "Gemma2ForCausalLM"), + ("gemma3", "Gemma3ForConditionalGeneration"), + ("gemma3_text", "Gemma3ForCausalLM"), + ("gemma3n", "Gemma3nForConditionalGeneration"), + ("gemma3n_text", "Gemma3nForCausalLM"), + ("git", "GitForCausalLM"), + ("glm", "GlmForCausalLM"), + ("glm4", "Glm4ForCausalLM"), + ("glm4_moe", "Glm4MoeForCausalLM"), + ("got_ocr2", "GotOcr2ForConditionalGeneration"), + ("gpt-sw3", "GPT2LMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("gpt_bigcode", "GPTBigCodeForCausalLM"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("gpt_neox", "GPTNeoXForCausalLM"), + ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), + ("gpt_oss", "GptOssForCausalLM"), + ("gptj", "GPTJForCausalLM"), + ("granite", "GraniteForCausalLM"), + ("granitemoe", "GraniteMoeForCausalLM"), + ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), + ("granitemoeshared", "GraniteMoeSharedForCausalLM"), + ("helium", "HeliumForCausalLM"), + ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"), + ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"), + ("jamba", "JambaForCausalLM"), + ("jetmoe", "JetMoeForCausalLM"), + ("lfm2", "Lfm2ForCausalLM"), + ("llama", "LlamaForCausalLM"), + ("llama4", "Llama4ForCausalLM"), + ("llama4_text", "Llama4ForCausalLM"), + ("longcat_flash", "LongcatFlashForCausalLM"), + ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), + ("marian", "MarianForCausalLM"), + ("mbart", "MBartForCausalLM"), + ("mega", "MegaForCausalLM"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("minimax", "MiniMaxForCausalLM"), + ("ministral", "MinistralForCausalLM"), + ("mistral", 
"MistralForCausalLM"), + ("mixtral", "MixtralForCausalLM"), + ("mllama", "MllamaForCausalLM"), + ("modernbert-decoder", "ModernBertDecoderForCausalLM"), + ("moshi", "MoshiForCausalLM"), + ("mpt", "MptForCausalLM"), + ("musicgen", "MusicgenForCausalLM"), + ("musicgen_melody", "MusicgenMelodyForCausalLM"), + ("mvp", "MvpForCausalLM"), + ("nemotron", "NemotronForCausalLM"), + ("olmo", "OlmoForCausalLM"), + ("olmo2", "Olmo2ForCausalLM"), + ("olmo3", "Olmo3ForCausalLM"), + ("olmoe", "OlmoeForCausalLM"), + ("open-llama", "OpenLlamaForCausalLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("opt", "OPTForCausalLM"), + ("pegasus", "PegasusForCausalLM"), + ("persimmon", "PersimmonForCausalLM"), + ("phi", "PhiForCausalLM"), + ("phi3", "Phi3ForCausalLM"), + ("phi4_multimodal", "Phi4MultimodalForCausalLM"), + ("phimoe", "PhimoeForCausalLM"), + ("plbart", "PLBartForCausalLM"), + ("prophetnet", "ProphetNetForCausalLM"), + ("qdqbert", "QDQBertLMHeadModel"), + ("qwen2", "Qwen2ForCausalLM"), + ("qwen2_moe", "Qwen2MoeForCausalLM"), + ("qwen3", "Qwen3ForCausalLM"), + ("qwen3_moe", "Qwen3MoeForCausalLM"), + ("qwen3_next", "Qwen3NextForCausalLM"), + ("recurrent_gemma", "RecurrentGemmaForCausalLM"), + ("reformer", "ReformerModelWithLMHead"), + ("rembert", "RemBertForCausalLM"), + ("roberta", "RobertaForCausalLM"), + ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"), + ("roc_bert", "RoCBertForCausalLM"), + ("roformer", "RoFormerForCausalLM"), + ("rwkv", "RwkvForCausalLM"), + ("seed_oss", "SeedOssForCausalLM"), + ("smollm3", "SmolLM3ForCausalLM"), + ("speech_to_text_2", "Speech2Text2ForCausalLM"), + ("stablelm", "StableLmForCausalLM"), + ("starcoder2", "Starcoder2ForCausalLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("trocr", "TrOCRForCausalLM"), + ("vaultgemma", "VaultGemmaForCausalLM"), + ("whisper", "WhisperForCausalLM"), + ("xglm", "XGLMForCausalLM"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-prophetnet", "XLMProphetNetForCausalLM"), + ("xlm-roberta", "XLMRobertaForCausalLM"), + ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), + ("xlnet", "XLNetLMHeadModel"), + ("xlstm", "xLSTMForCausalLM"), + ("xmod", "XmodForCausalLM"), + ("zamba", "ZambaForCausalLM"), + ("zamba2", "Zamba2ForCausalLM"), + ] +) + +MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict( + [ + # Model for Image mapping + ("aimv2_vision_model", "Aimv2VisionModel"), + ("beit", "BeitModel"), + ("bit", "BitModel"), + ("cohere2_vision", "Cohere2VisionModel"), + ("conditional_detr", "ConditionalDetrModel"), + ("convnext", "ConvNextModel"), + ("convnextv2", "ConvNextV2Model"), + ("dab-detr", "DabDetrModel"), + ("data2vec-vision", "Data2VecVisionModel"), + ("deformable_detr", "DeformableDetrModel"), + ("deit", "DeiTModel"), + ("depth_pro", "DepthProModel"), + ("deta", "DetaModel"), + ("detr", "DetrModel"), + ("dinat", "DinatModel"), + ("dinov2", "Dinov2Model"), + ("dinov2_with_registers", "Dinov2WithRegistersModel"), + ("dinov3_convnext", "DINOv3ConvNextModel"), + ("dinov3_vit", "DINOv3ViTModel"), + ("dpt", "DPTModel"), + ("efficientformer", "EfficientFormerModel"), + ("efficientnet", "EfficientNetModel"), + ("focalnet", "FocalNetModel"), + ("glpn", "GLPNModel"), + ("hiera", "HieraModel"), + ("ijepa", "IJepaModel"), + ("imagegpt", "ImageGPTModel"), + ("levit", "LevitModel"), + ("llama4", "Llama4VisionModel"), + ("mlcd", "MLCDVisionModel"), + ("mllama", "MllamaVisionModel"), + ("mobilenet_v1", "MobileNetV1Model"), + ("mobilenet_v2", "MobileNetV2Model"), + ("mobilevit", "MobileViTModel"), + ("mobilevitv2", "MobileViTV2Model"), + ("nat", "NatModel"), + 
("poolformer", "PoolFormerModel"), + ("pvt", "PvtModel"), + ("regnet", "RegNetModel"), + ("resnet", "ResNetModel"), + ("segformer", "SegformerModel"), + ("siglip_vision_model", "SiglipVisionModel"), + ("swiftformer", "SwiftFormerModel"), + ("swin", "SwinModel"), + ("swin2sr", "Swin2SRModel"), + ("swinv2", "Swinv2Model"), + ("table-transformer", "TableTransformerModel"), + ("timesformer", "TimesformerModel"), + ("timm_backbone", "TimmBackbone"), + ("timm_wrapper", "TimmWrapperModel"), + ("van", "VanModel"), + ("videomae", "VideoMAEModel"), + ("vit", "ViTModel"), + ("vit_hybrid", "ViTHybridModel"), + ("vit_mae", "ViTMAEModel"), + ("vit_msn", "ViTMSNModel"), + ("vitdet", "VitDetModel"), + ("vivit", "VivitModel"), + ("yolos", "YolosModel"), + ] +) + +MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( + [ + ("deit", "DeiTForMaskedImageModeling"), + ("focalnet", "FocalNetForMaskedImageModeling"), + ("swin", "SwinForMaskedImageModeling"), + ("swinv2", "Swinv2ForMaskedImageModeling"), + ("vit", "ViTForMaskedImageModeling"), + ] +) + + +MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( + # Model for Causal Image Modeling mapping + [ + ("imagegpt", "ImageGPTForCausalImageModeling"), + ] +) + +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Image Classification mapping + ("beit", "BeitForImageClassification"), + ("bit", "BitForImageClassification"), + ("clip", "CLIPForImageClassification"), + ("convnext", "ConvNextForImageClassification"), + ("convnextv2", "ConvNextV2ForImageClassification"), + ("cvt", "CvtForImageClassification"), + ("data2vec-vision", "Data2VecVisionForImageClassification"), + ( + "deit", + ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"), + ), + ("dinat", "DinatForImageClassification"), + ("dinov2", "Dinov2ForImageClassification"), + ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"), + ("donut-swin", "DonutSwinForImageClassification"), + ( + "efficientformer", + ( + "EfficientFormerForImageClassification", + "EfficientFormerForImageClassificationWithTeacher", + ), + ), + ("efficientnet", "EfficientNetForImageClassification"), + ("focalnet", "FocalNetForImageClassification"), + ("hgnet_v2", "HGNetV2ForImageClassification"), + ("hiera", "HieraForImageClassification"), + ("ijepa", "IJepaForImageClassification"), + ("imagegpt", "ImageGPTForImageClassification"), + ( + "levit", + ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"), + ), + ("metaclip_2", "MetaClip2ForImageClassification"), + ("mobilenet_v1", "MobileNetV1ForImageClassification"), + ("mobilenet_v2", "MobileNetV2ForImageClassification"), + ("mobilevit", "MobileViTForImageClassification"), + ("mobilevitv2", "MobileViTV2ForImageClassification"), + ("nat", "NatForImageClassification"), + ( + "perceiver", + ( + "PerceiverForImageClassificationLearned", + "PerceiverForImageClassificationFourier", + "PerceiverForImageClassificationConvProcessing", + ), + ), + ("poolformer", "PoolFormerForImageClassification"), + ("pvt", "PvtForImageClassification"), + ("pvt_v2", "PvtV2ForImageClassification"), + ("regnet", "RegNetForImageClassification"), + ("resnet", "ResNetForImageClassification"), + ("segformer", "SegformerForImageClassification"), + ("shieldgemma2", "ShieldGemma2ForImageClassification"), + ("siglip", "SiglipForImageClassification"), + ("siglip2", "Siglip2ForImageClassification"), + ("swiftformer", "SwiftFormerForImageClassification"), + ("swin", "SwinForImageClassification"), + ("swinv2", 
"Swinv2ForImageClassification"), + ("textnet", "TextNetForImageClassification"), + ("timm_wrapper", "TimmWrapperForImageClassification"), + ("van", "VanForImageClassification"), + ("vit", "ViTForImageClassification"), + ("vit_hybrid", "ViTHybridForImageClassification"), + ("vit_msn", "ViTMSNForImageClassification"), + ] +) + +MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Do not add new models here, this class will be deprecated in the future. + # Model for Image Segmentation mapping + ("detr", "DetrForSegmentation"), + ] +) + +MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Semantic Segmentation mapping + ("beit", "BeitForSemanticSegmentation"), + ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"), + ("dpt", "DPTForSemanticSegmentation"), + ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"), + ("mobilevit", "MobileViTForSemanticSegmentation"), + ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"), + ("segformer", "SegformerForSemanticSegmentation"), + ("upernet", "UperNetForSemanticSegmentation"), + ] +) + +MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Instance Segmentation mapping + # MaskFormerForInstanceSegmentation can be removed from this mapping in v5 + ("maskformer", "MaskFormerForInstanceSegmentation"), + ] +) + +MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Universal Segmentation mapping + ("detr", "DetrForSegmentation"), + ("eomt", "EomtForUniversalSegmentation"), + ("mask2former", "Mask2FormerForUniversalSegmentation"), + ("maskformer", "MaskFormerForInstanceSegmentation"), + ("oneformer", "OneFormerForUniversalSegmentation"), + ] +) + +MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("timesformer", "TimesformerForVideoClassification"), + ("videomae", "VideoMAEForVideoClassification"), + ("vivit", "VivitForVideoClassification"), + ("vjepa2", "VJEPA2ForVideoClassification"), + ] +) + +MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("blip", "BlipForConditionalGeneration"), + ("blip-2", "Blip2ForConditionalGeneration"), + ("chameleon", "ChameleonForConditionalGeneration"), + ("git", "GitForCausalLM"), + ("idefics2", "Idefics2ForConditionalGeneration"), + ("idefics3", "Idefics3ForConditionalGeneration"), + ("instructblip", "InstructBlipForConditionalGeneration"), + ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"), + ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), + ("llava", "LlavaForConditionalGeneration"), + ("llava_next", "LlavaNextForConditionalGeneration"), + ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), + ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + ("mistral3", "Mistral3ForConditionalGeneration"), + ("mllama", "MllamaForConditionalGeneration"), + ("ovis2", "Ovis2ForConditionalGeneration"), + ("paligemma", "PaliGemmaForConditionalGeneration"), + ("pix2struct", "Pix2StructForConditionalGeneration"), + ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), + ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + ("qwen3_vl", "Qwen3VLForConditionalGeneration"), + ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), + ("video_llava", "VideoLlavaForConditionalGeneration"), + ("vipllava", "VipLlavaForConditionalGeneration"), + ("vision-encoder-decoder", "VisionEncoderDecoderModel"), + ] +) + +MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict( + [ + ("colpali", "ColPaliForRetrieval"), + ] +) + 
+MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( + [ + ("aria", "AriaForConditionalGeneration"), + ("aya_vision", "AyaVisionForConditionalGeneration"), + ("blip", "BlipForConditionalGeneration"), + ("blip-2", "Blip2ForConditionalGeneration"), + ("chameleon", "ChameleonForConditionalGeneration"), + ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), + ("deepseek_vl", "DeepseekVLForConditionalGeneration"), + ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"), + ("emu3", "Emu3ForConditionalGeneration"), + ("evolla", "EvollaForProteinText2Text"), + ("florence2", "Florence2ForConditionalGeneration"), + ("fuyu", "FuyuForCausalLM"), + ("gemma3", "Gemma3ForConditionalGeneration"), + ("gemma3n", "Gemma3nForConditionalGeneration"), + ("git", "GitForCausalLM"), + ("glm4v", "Glm4vForConditionalGeneration"), + ("glm4v_moe", "Glm4vMoeForConditionalGeneration"), + ("got_ocr2", "GotOcr2ForConditionalGeneration"), + ("idefics", "IdeficsForVisionText2Text"), + ("idefics2", "Idefics2ForConditionalGeneration"), + ("idefics3", "Idefics3ForConditionalGeneration"), + ("instructblip", "InstructBlipForConditionalGeneration"), + ("internvl", "InternVLForConditionalGeneration"), + ("janus", "JanusForConditionalGeneration"), + ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("kosmos-2.5", "Kosmos2_5ForConditionalGeneration"), + ("lfm2_vl", "Lfm2VlForConditionalGeneration"), + ("llama4", "Llama4ForConditionalGeneration"), + ("llava", "LlavaForConditionalGeneration"), + ("llava_next", "LlavaNextForConditionalGeneration"), + ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), + ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + ("mistral3", "Mistral3ForConditionalGeneration"), + ("mllama", "MllamaForConditionalGeneration"), + ("ovis2", "Ovis2ForConditionalGeneration"), + ("paligemma", "PaliGemmaForConditionalGeneration"), + ("perception_lm", "PerceptionLMForConditionalGeneration"), + ("pix2struct", "Pix2StructForConditionalGeneration"), + ("pixtral", "LlavaForConditionalGeneration"), + ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), + ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + ("qwen3_vl", "Qwen3VLForConditionalGeneration"), + ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), + ("shieldgemma2", "Gemma3ForConditionalGeneration"), + ("smolvlm", "SmolVLMForConditionalGeneration"), + ("udop", "UdopForConditionalGeneration"), + ("vipllava", "VipLlavaForConditionalGeneration"), + ("vision-encoder-decoder", "VisionEncoderDecoderModel"), + ] +) + +MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Masked LM mapping + ("albert", "AlbertForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("bert", "BertForMaskedLM"), + ("big_bird", "BigBirdForMaskedLM"), + ("camembert", "CamembertForMaskedLM"), + ("convbert", "ConvBertForMaskedLM"), + ("data2vec-text", "Data2VecTextForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("distilbert", "DistilBertForMaskedLM"), + ("electra", "ElectraForMaskedLM"), + ("ernie", "ErnieForMaskedLM"), + ("esm", "EsmForMaskedLM"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("fnet", "FNetForMaskedLM"), + ("funnel", "FunnelForMaskedLM"), + ("ibert", "IBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("longformer", "LongformerForMaskedLM"), + ("luke", "LukeForMaskedLM"), + ("mbart", "MBartForConditionalGeneration"), + ("mega", "MegaForMaskedLM"), + ("megatron-bert", "MegatronBertForMaskedLM"), + ("mobilebert", "MobileBertForMaskedLM"), + 
("modernbert", "ModernBertForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), + ("mra", "MraForMaskedLM"), + ("mvp", "MvpForConditionalGeneration"), + ("nezha", "NezhaForMaskedLM"), + ("nystromformer", "NystromformerForMaskedLM"), + ("perceiver", "PerceiverForMaskedLM"), + ("qdqbert", "QDQBertForMaskedLM"), + ("reformer", "ReformerForMaskedLM"), + ("rembert", "RemBertForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), + ("roc_bert", "RoCBertForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), + ("xmod", "XmodForMaskedLM"), + ("yoso", "YosoForMaskedLM"), + ] +) + +MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( + [ + # Model for Object Detection mapping + ("conditional_detr", "ConditionalDetrForObjectDetection"), + ("d_fine", "DFineForObjectDetection"), + ("dab-detr", "DabDetrForObjectDetection"), + ("deformable_detr", "DeformableDetrForObjectDetection"), + ("deta", "DetaForObjectDetection"), + ("detr", "DetrForObjectDetection"), + ("rt_detr", "RTDetrForObjectDetection"), + ("rt_detr_v2", "RTDetrV2ForObjectDetection"), + ("table-transformer", "TableTransformerForObjectDetection"), + ("yolos", "YolosForObjectDetection"), + ] +) + +MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( + [ + # Model for Zero Shot Object Detection mapping + ("grounding-dino", "GroundingDinoForObjectDetection"), + ("mm-grounding-dino", "MMGroundingDinoForObjectDetection"), + ("omdet-turbo", "OmDetTurboForObjectDetection"), + ("owlv2", "Owlv2ForObjectDetection"), + ("owlvit", "OwlViTForObjectDetection"), + ] +) + +MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict( + [ + # Model for depth estimation mapping + ("depth_anything", "DepthAnythingForDepthEstimation"), + ("depth_pro", "DepthProForDepthEstimation"), + ("dpt", "DPTForDepthEstimation"), + ("glpn", "GLPNForDepthEstimation"), + ("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"), + ("zoedepth", "ZoeDepthForDepthEstimation"), + ] +) +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Seq2Seq Causal LM mapping + ("bart", "BartForConditionalGeneration"), + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("blenderbot", "BlenderbotForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("encoder-decoder", "EncoderDecoderModel"), + ("fsmt", "FSMTForConditionalGeneration"), + ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"), + ("granite_speech", "GraniteSpeechForConditionalGeneration"), + ("led", "LEDForConditionalGeneration"), + ("longt5", "LongT5ForConditionalGeneration"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("marian", "MarianMTModel"), + ("mbart", "MBartForConditionalGeneration"), + ("mt5", "MT5ForConditionalGeneration"), + ("mvp", "MvpForConditionalGeneration"), + ("nllb-moe", "NllbMoeForConditionalGeneration"), + ("pegasus", "PegasusForConditionalGeneration"), + ("pegasus_x", "PegasusXForConditionalGeneration"), + ("plbart", "PLBartForConditionalGeneration"), + ("prophetnet", "ProphetNetForConditionalGeneration"), + ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForTextToText"), + ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), + ("switch_transformers", 
"SwitchTransformersForConditionalGeneration"), + ("t5", "T5ForConditionalGeneration"), + ("t5gemma", "T5GemmaForConditionalGeneration"), + ("umt5", "UMT5ForConditionalGeneration"), + ("voxtral", "VoxtralForConditionalGeneration"), + ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), + ] +) + +MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("dia", "DiaForConditionalGeneration"), + ("granite_speech", "GraniteSpeechForConditionalGeneration"), + ("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"), + ("moonshine", "MoonshineForConditionalGeneration"), + ("pop2piano", "Pop2PianoForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForSpeechToText"), + ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"), + ("speech-encoder-decoder", "SpeechEncoderDecoderModel"), + ("speech_to_text", "Speech2TextForConditionalGeneration"), + ("speecht5", "SpeechT5ForSpeechToText"), + ("whisper", "WhisperForConditionalGeneration"), + ] +) + +MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Sequence Classification mapping + ("albert", "AlbertForSequenceClassification"), + ("arcee", "ArceeForSequenceClassification"), + ("bart", "BartForSequenceClassification"), + ("bert", "BertForSequenceClassification"), + ("big_bird", "BigBirdForSequenceClassification"), + ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), + ("biogpt", "BioGptForSequenceClassification"), + ("bloom", "BloomForSequenceClassification"), + ("camembert", "CamembertForSequenceClassification"), + ("canine", "CanineForSequenceClassification"), + ("code_llama", "LlamaForSequenceClassification"), + ("convbert", "ConvBertForSequenceClassification"), + ("ctrl", "CTRLForSequenceClassification"), + ("data2vec-text", "Data2VecTextForSequenceClassification"), + ("deberta", "DebertaForSequenceClassification"), + ("deberta-v2", "DebertaV2ForSequenceClassification"), + ("deepseek_v2", "DeepseekV2ForSequenceClassification"), + ("deepseek_v3", "DeepseekV3ForSequenceClassification"), + ("diffllama", "DiffLlamaForSequenceClassification"), + ("distilbert", "DistilBertForSequenceClassification"), + ("doge", "DogeForSequenceClassification"), + ("electra", "ElectraForSequenceClassification"), + ("ernie", "ErnieForSequenceClassification"), + ("ernie_m", "ErnieMForSequenceClassification"), + ("esm", "EsmForSequenceClassification"), + ("exaone4", "Exaone4ForSequenceClassification"), + ("falcon", "FalconForSequenceClassification"), + ("flaubert", "FlaubertForSequenceClassification"), + ("fnet", "FNetForSequenceClassification"), + ("funnel", "FunnelForSequenceClassification"), + ("gemma", "GemmaForSequenceClassification"), + ("gemma2", "Gemma2ForSequenceClassification"), + ("gemma3", "Gemma3ForSequenceClassification"), + ("gemma3_text", "Gemma3TextForSequenceClassification"), + ("glm", "GlmForSequenceClassification"), + ("glm4", "Glm4ForSequenceClassification"), + ("gpt-sw3", "GPT2ForSequenceClassification"), + ("gpt2", "GPT2ForSequenceClassification"), + ("gpt_bigcode", "GPTBigCodeForSequenceClassification"), + ("gpt_neo", "GPTNeoForSequenceClassification"), + ("gpt_neox", "GPTNeoXForSequenceClassification"), + ("gpt_oss", "GptOssForSequenceClassification"), + ("gptj", "GPTJForSequenceClassification"), + ("helium", "HeliumForSequenceClassification"), + ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"), + ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"), + ("ibert", "IBertForSequenceClassification"), + ("jamba", "JambaForSequenceClassification"), + ("jetmoe", 
"JetMoeForSequenceClassification"), + ("layoutlm", "LayoutLMForSequenceClassification"), + ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), + ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), + ("led", "LEDForSequenceClassification"), + ("lilt", "LiltForSequenceClassification"), + ("llama", "LlamaForSequenceClassification"), + ("longformer", "LongformerForSequenceClassification"), + ("luke", "LukeForSequenceClassification"), + ("markuplm", "MarkupLMForSequenceClassification"), + ("mbart", "MBartForSequenceClassification"), + ("mega", "MegaForSequenceClassification"), + ("megatron-bert", "MegatronBertForSequenceClassification"), + ("minimax", "MiniMaxForSequenceClassification"), + ("ministral", "MinistralForSequenceClassification"), + ("mistral", "MistralForSequenceClassification"), + ("mixtral", "MixtralForSequenceClassification"), + ("mobilebert", "MobileBertForSequenceClassification"), + ("modernbert", "ModernBertForSequenceClassification"), + ("modernbert-decoder", "ModernBertDecoderForSequenceClassification"), + ("mpnet", "MPNetForSequenceClassification"), + ("mpt", "MptForSequenceClassification"), + ("mra", "MraForSequenceClassification"), + ("mt5", "MT5ForSequenceClassification"), + ("mvp", "MvpForSequenceClassification"), + ("nemotron", "NemotronForSequenceClassification"), + ("nezha", "NezhaForSequenceClassification"), + ("nystromformer", "NystromformerForSequenceClassification"), + ("open-llama", "OpenLlamaForSequenceClassification"), + ("openai-gpt", "OpenAIGPTForSequenceClassification"), + ("opt", "OPTForSequenceClassification"), + ("perceiver", "PerceiverForSequenceClassification"), + ("persimmon", "PersimmonForSequenceClassification"), + ("phi", "PhiForSequenceClassification"), + ("phi3", "Phi3ForSequenceClassification"), + ("phimoe", "PhimoeForSequenceClassification"), + ("plbart", "PLBartForSequenceClassification"), + ("qdqbert", "QDQBertForSequenceClassification"), + ("qwen2", "Qwen2ForSequenceClassification"), + ("qwen2_moe", "Qwen2MoeForSequenceClassification"), + ("qwen3", "Qwen3ForSequenceClassification"), + ("qwen3_moe", "Qwen3MoeForSequenceClassification"), + ("qwen3_next", "Qwen3NextForSequenceClassification"), + ("reformer", "ReformerForSequenceClassification"), + ("rembert", "RemBertForSequenceClassification"), + ("roberta", "RobertaForSequenceClassification"), + ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"), + ("roc_bert", "RoCBertForSequenceClassification"), + ("roformer", "RoFormerForSequenceClassification"), + ("seed_oss", "SeedOssForSequenceClassification"), + ("smollm3", "SmolLM3ForSequenceClassification"), + ("squeezebert", "SqueezeBertForSequenceClassification"), + ("stablelm", "StableLmForSequenceClassification"), + ("starcoder2", "Starcoder2ForSequenceClassification"), + ("t5", "T5ForSequenceClassification"), + ("t5gemma", "T5GemmaForSequenceClassification"), + ("tapas", "TapasForSequenceClassification"), + ("transfo-xl", "TransfoXLForSequenceClassification"), + ("umt5", "UMT5ForSequenceClassification"), + ("xlm", "XLMForSequenceClassification"), + ("xlm-roberta", "XLMRobertaForSequenceClassification"), + ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"), + ("xlnet", "XLNetForSequenceClassification"), + ("xmod", "XmodForSequenceClassification"), + ("yoso", "YosoForSequenceClassification"), + ("zamba", "ZambaForSequenceClassification"), + ("zamba2", "Zamba2ForSequenceClassification"), + ] +) + +MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Question Answering mapping + ("albert", 
"AlbertForQuestionAnswering"), + ("arcee", "ArceeForQuestionAnswering"), + ("bart", "BartForQuestionAnswering"), + ("bert", "BertForQuestionAnswering"), + ("big_bird", "BigBirdForQuestionAnswering"), + ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), + ("bloom", "BloomForQuestionAnswering"), + ("camembert", "CamembertForQuestionAnswering"), + ("canine", "CanineForQuestionAnswering"), + ("convbert", "ConvBertForQuestionAnswering"), + ("data2vec-text", "Data2VecTextForQuestionAnswering"), + ("deberta", "DebertaForQuestionAnswering"), + ("deberta-v2", "DebertaV2ForQuestionAnswering"), + ("diffllama", "DiffLlamaForQuestionAnswering"), + ("distilbert", "DistilBertForQuestionAnswering"), + ("electra", "ElectraForQuestionAnswering"), + ("ernie", "ErnieForQuestionAnswering"), + ("ernie_m", "ErnieMForQuestionAnswering"), + ("exaone4", "Exaone4ForQuestionAnswering"), + ("falcon", "FalconForQuestionAnswering"), + ("flaubert", "FlaubertForQuestionAnsweringSimple"), + ("fnet", "FNetForQuestionAnswering"), + ("funnel", "FunnelForQuestionAnswering"), + ("gpt2", "GPT2ForQuestionAnswering"), + ("gpt_neo", "GPTNeoForQuestionAnswering"), + ("gpt_neox", "GPTNeoXForQuestionAnswering"), + ("gptj", "GPTJForQuestionAnswering"), + ("ibert", "IBertForQuestionAnswering"), + ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), + ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), + ("led", "LEDForQuestionAnswering"), + ("lilt", "LiltForQuestionAnswering"), + ("llama", "LlamaForQuestionAnswering"), + ("longformer", "LongformerForQuestionAnswering"), + ("luke", "LukeForQuestionAnswering"), + ("lxmert", "LxmertForQuestionAnswering"), + ("markuplm", "MarkupLMForQuestionAnswering"), + ("mbart", "MBartForQuestionAnswering"), + ("mega", "MegaForQuestionAnswering"), + ("megatron-bert", "MegatronBertForQuestionAnswering"), + ("minimax", "MiniMaxForQuestionAnswering"), + ("ministral", "MinistralForQuestionAnswering"), + ("mistral", "MistralForQuestionAnswering"), + ("mixtral", "MixtralForQuestionAnswering"), + ("mobilebert", "MobileBertForQuestionAnswering"), + ("modernbert", "ModernBertForQuestionAnswering"), + ("mpnet", "MPNetForQuestionAnswering"), + ("mpt", "MptForQuestionAnswering"), + ("mra", "MraForQuestionAnswering"), + ("mt5", "MT5ForQuestionAnswering"), + ("mvp", "MvpForQuestionAnswering"), + ("nemotron", "NemotronForQuestionAnswering"), + ("nezha", "NezhaForQuestionAnswering"), + ("nystromformer", "NystromformerForQuestionAnswering"), + ("opt", "OPTForQuestionAnswering"), + ("qdqbert", "QDQBertForQuestionAnswering"), + ("qwen2", "Qwen2ForQuestionAnswering"), + ("qwen2_moe", "Qwen2MoeForQuestionAnswering"), + ("qwen3", "Qwen3ForQuestionAnswering"), + ("qwen3_moe", "Qwen3MoeForQuestionAnswering"), + ("qwen3_next", "Qwen3NextForQuestionAnswering"), + ("reformer", "ReformerForQuestionAnswering"), + ("rembert", "RemBertForQuestionAnswering"), + ("roberta", "RobertaForQuestionAnswering"), + ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"), + ("roc_bert", "RoCBertForQuestionAnswering"), + ("roformer", "RoFormerForQuestionAnswering"), + ("seed_oss", "SeedOssForQuestionAnswering"), + ("smollm3", "SmolLM3ForQuestionAnswering"), + ("splinter", "SplinterForQuestionAnswering"), + ("squeezebert", "SqueezeBertForQuestionAnswering"), + ("t5", "T5ForQuestionAnswering"), + ("umt5", "UMT5ForQuestionAnswering"), + ("xlm", "XLMForQuestionAnsweringSimple"), + ("xlm-roberta", "XLMRobertaForQuestionAnswering"), + ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"), + ("xlnet", "XLNetForQuestionAnsweringSimple"), 
+ ("xmod", "XmodForQuestionAnswering"), + ("yoso", "YosoForQuestionAnswering"), + ] +) + +MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Table Question Answering mapping + ("tapas", "TapasForQuestionAnswering"), + ] +) + +MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + ("blip", "BlipForQuestionAnswering"), + ("blip-2", "Blip2ForConditionalGeneration"), + ("vilt", "ViltForQuestionAnswering"), + ] +) + +MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + ("layoutlm", "LayoutLMForQuestionAnswering"), + ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), + ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), + ] +) + +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Token Classification mapping + ("albert", "AlbertForTokenClassification"), + ("apertus", "ApertusForTokenClassification"), + ("arcee", "ArceeForTokenClassification"), + ("bert", "BertForTokenClassification"), + ("big_bird", "BigBirdForTokenClassification"), + ("biogpt", "BioGptForTokenClassification"), + ("bloom", "BloomForTokenClassification"), + ("bros", "BrosForTokenClassification"), + ("camembert", "CamembertForTokenClassification"), + ("canine", "CanineForTokenClassification"), + ("convbert", "ConvBertForTokenClassification"), + ("data2vec-text", "Data2VecTextForTokenClassification"), + ("deberta", "DebertaForTokenClassification"), + ("deberta-v2", "DebertaV2ForTokenClassification"), + ("deepseek_v3", "DeepseekV3ForTokenClassification"), + ("diffllama", "DiffLlamaForTokenClassification"), + ("distilbert", "DistilBertForTokenClassification"), + ("electra", "ElectraForTokenClassification"), + ("ernie", "ErnieForTokenClassification"), + ("ernie_m", "ErnieMForTokenClassification"), + ("esm", "EsmForTokenClassification"), + ("exaone4", "Exaone4ForTokenClassification"), + ("falcon", "FalconForTokenClassification"), + ("flaubert", "FlaubertForTokenClassification"), + ("fnet", "FNetForTokenClassification"), + ("funnel", "FunnelForTokenClassification"), + ("gemma", "GemmaForTokenClassification"), + ("gemma2", "Gemma2ForTokenClassification"), + ("glm", "GlmForTokenClassification"), + ("glm4", "Glm4ForTokenClassification"), + ("gpt-sw3", "GPT2ForTokenClassification"), + ("gpt2", "GPT2ForTokenClassification"), + ("gpt_bigcode", "GPTBigCodeForTokenClassification"), + ("gpt_neo", "GPTNeoForTokenClassification"), + ("gpt_neox", "GPTNeoXForTokenClassification"), + ("gpt_oss", "GptOssForTokenClassification"), + ("helium", "HeliumForTokenClassification"), + ("ibert", "IBertForTokenClassification"), + ("layoutlm", "LayoutLMForTokenClassification"), + ("layoutlmv2", "LayoutLMv2ForTokenClassification"), + ("layoutlmv3", "LayoutLMv3ForTokenClassification"), + ("lilt", "LiltForTokenClassification"), + ("llama", "LlamaForTokenClassification"), + ("longformer", "LongformerForTokenClassification"), + ("luke", "LukeForTokenClassification"), + ("markuplm", "MarkupLMForTokenClassification"), + ("mega", "MegaForTokenClassification"), + ("megatron-bert", "MegatronBertForTokenClassification"), + ("minimax", "MiniMaxForTokenClassification"), + ("ministral", "MinistralForTokenClassification"), + ("mistral", "MistralForTokenClassification"), + ("mixtral", "MixtralForTokenClassification"), + ("mobilebert", "MobileBertForTokenClassification"), + ("modernbert", "ModernBertForTokenClassification"), + ("mpnet", "MPNetForTokenClassification"), + ("mpt", "MptForTokenClassification"), + ("mra", "MraForTokenClassification"), + ("mt5", "MT5ForTokenClassification"), + 
("nemotron", "NemotronForTokenClassification"), + ("nezha", "NezhaForTokenClassification"), + ("nystromformer", "NystromformerForTokenClassification"), + ("persimmon", "PersimmonForTokenClassification"), + ("phi", "PhiForTokenClassification"), + ("phi3", "Phi3ForTokenClassification"), + ("qdqbert", "QDQBertForTokenClassification"), + ("qwen2", "Qwen2ForTokenClassification"), + ("qwen2_moe", "Qwen2MoeForTokenClassification"), + ("qwen3", "Qwen3ForTokenClassification"), + ("qwen3_moe", "Qwen3MoeForTokenClassification"), + ("qwen3_next", "Qwen3NextForTokenClassification"), + ("rembert", "RemBertForTokenClassification"), + ("roberta", "RobertaForTokenClassification"), + ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"), + ("roc_bert", "RoCBertForTokenClassification"), + ("roformer", "RoFormerForTokenClassification"), + ("seed_oss", "SeedOssForTokenClassification"), + ("smollm3", "SmolLM3ForTokenClassification"), + ("squeezebert", "SqueezeBertForTokenClassification"), + ("stablelm", "StableLmForTokenClassification"), + ("starcoder2", "Starcoder2ForTokenClassification"), + ("t5", "T5ForTokenClassification"), + ("t5gemma", "T5GemmaForTokenClassification"), + ("umt5", "UMT5ForTokenClassification"), + ("xlm", "XLMForTokenClassification"), + ("xlm-roberta", "XLMRobertaForTokenClassification"), + ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"), + ("xlnet", "XLNetForTokenClassification"), + ("xmod", "XmodForTokenClassification"), + ("yoso", "YosoForTokenClassification"), + ] +) + +MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( + [ + # Model for Multiple Choice mapping + ("albert", "AlbertForMultipleChoice"), + ("bert", "BertForMultipleChoice"), + ("big_bird", "BigBirdForMultipleChoice"), + ("camembert", "CamembertForMultipleChoice"), + ("canine", "CanineForMultipleChoice"), + ("convbert", "ConvBertForMultipleChoice"), + ("data2vec-text", "Data2VecTextForMultipleChoice"), + ("deberta-v2", "DebertaV2ForMultipleChoice"), + ("distilbert", "DistilBertForMultipleChoice"), + ("electra", "ElectraForMultipleChoice"), + ("ernie", "ErnieForMultipleChoice"), + ("ernie_m", "ErnieMForMultipleChoice"), + ("flaubert", "FlaubertForMultipleChoice"), + ("fnet", "FNetForMultipleChoice"), + ("funnel", "FunnelForMultipleChoice"), + ("ibert", "IBertForMultipleChoice"), + ("longformer", "LongformerForMultipleChoice"), + ("luke", "LukeForMultipleChoice"), + ("mega", "MegaForMultipleChoice"), + ("megatron-bert", "MegatronBertForMultipleChoice"), + ("mobilebert", "MobileBertForMultipleChoice"), + ("modernbert", "ModernBertForMultipleChoice"), + ("mpnet", "MPNetForMultipleChoice"), + ("mra", "MraForMultipleChoice"), + ("nezha", "NezhaForMultipleChoice"), + ("nystromformer", "NystromformerForMultipleChoice"), + ("qdqbert", "QDQBertForMultipleChoice"), + ("rembert", "RemBertForMultipleChoice"), + ("roberta", "RobertaForMultipleChoice"), + ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"), + ("roc_bert", "RoCBertForMultipleChoice"), + ("roformer", "RoFormerForMultipleChoice"), + ("squeezebert", "SqueezeBertForMultipleChoice"), + ("xlm", "XLMForMultipleChoice"), + ("xlm-roberta", "XLMRobertaForMultipleChoice"), + ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"), + ("xlnet", "XLNetForMultipleChoice"), + ("xmod", "XmodForMultipleChoice"), + ("yoso", "YosoForMultipleChoice"), + ] +) + +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("bert", "BertForNextSentencePrediction"), + ("ernie", "ErnieForNextSentencePrediction"), + ("fnet", 
"FNetForNextSentencePrediction"), + ("megatron-bert", "MegatronBertForNextSentencePrediction"), + ("mobilebert", "MobileBertForNextSentencePrediction"), + ("nezha", "NezhaForNextSentencePrediction"), + ("qdqbert", "QDQBertForNextSentencePrediction"), + ] +) + +MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Audio Classification mapping + ("audio-spectrogram-transformer", "ASTForAudioClassification"), + ("data2vec-audio", "Data2VecAudioForSequenceClassification"), + ("hubert", "HubertForSequenceClassification"), + ("sew", "SEWForSequenceClassification"), + ("sew-d", "SEWDForSequenceClassification"), + ("unispeech", "UniSpeechForSequenceClassification"), + ("unispeech-sat", "UniSpeechSatForSequenceClassification"), + ("wav2vec2", "Wav2Vec2ForSequenceClassification"), + ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"), + ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"), + ("wavlm", "WavLMForSequenceClassification"), + ("whisper", "WhisperForAudioClassification"), + ] +) + +MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( + [ + # Model for Connectionist temporal classification (CTC) mapping + ("data2vec-audio", "Data2VecAudioForCTC"), + ("hubert", "HubertForCTC"), + ("mctct", "MCTCTForCTC"), + ("parakeet_ctc", "ParakeetForCTC"), + ("sew", "SEWForCTC"), + ("sew-d", "SEWDForCTC"), + ("unispeech", "UniSpeechForCTC"), + ("unispeech-sat", "UniSpeechSatForCTC"), + ("wav2vec2", "Wav2Vec2ForCTC"), + ("wav2vec2-bert", "Wav2Vec2BertForCTC"), + ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"), + ("wavlm", "WavLMForCTC"), + ] +) + +MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Audio Classification mapping + ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"), + ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"), + ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"), + ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"), + ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"), + ("wavlm", "WavLMForAudioFrameClassification"), + ] +) + +MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( + [ + # Model for Audio Classification mapping + ("data2vec-audio", "Data2VecAudioForXVector"), + ("unispeech-sat", "UniSpeechSatForXVector"), + ("wav2vec2", "Wav2Vec2ForXVector"), + ("wav2vec2-bert", "Wav2Vec2BertForXVector"), + ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"), + ("wavlm", "WavLMForXVector"), + ] +) + +MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict( + [ + # Model for Text-To-Spectrogram mapping + ("fastspeech2_conformer", "FastSpeech2ConformerModel"), + ("speecht5", "SpeechT5ForTextToSpeech"), + ] +) + +MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict( + [ + # Model for Text-To-Waveform mapping + ("bark", "BarkModel"), + ("csm", "CsmForConditionalGeneration"), + ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"), + ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), + ("musicgen", "MusicgenForConditionalGeneration"), + ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"), + ("qwen2_5_omni", "Qwen2_5OmniForConditionalGeneration"), + ("qwen3_omni_moe", "Qwen3OmniMoeForConditionalGeneration"), + ("seamless_m4t", "SeamlessM4TForTextToSpeech"), + ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"), + ("vits", "VitsModel"), + ] +) + +MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Zero Shot Image Classification mapping + ("align", "AlignModel"), + ("altclip", 
"AltCLIPModel"), + ("blip", "BlipModel"), + ("blip-2", "Blip2ForImageTextRetrieval"), + ("chinese_clip", "ChineseCLIPModel"), + ("clip", "CLIPModel"), + ("clipseg", "CLIPSegModel"), + ("metaclip_2", "MetaClip2Model"), + ("siglip", "SiglipModel"), + ("siglip2", "Siglip2Model"), + ] +) + +MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( + [ + # Backbone mapping + ("beit", "BeitBackbone"), + ("bit", "BitBackbone"), + ("convnext", "ConvNextBackbone"), + ("convnextv2", "ConvNextV2Backbone"), + ("dinat", "DinatBackbone"), + ("dinov2", "Dinov2Backbone"), + ("dinov2_with_registers", "Dinov2WithRegistersBackbone"), + ("focalnet", "FocalNetBackbone"), + ("hgnet_v2", "HGNetV2Backbone"), + ("hiera", "HieraBackbone"), + ("maskformer-swin", "MaskFormerSwinBackbone"), + ("nat", "NatBackbone"), + ("pvt_v2", "PvtV2Backbone"), + ("resnet", "ResNetBackbone"), + ("rt_detr_resnet", "RTDetrResNetBackbone"), + ("swin", "SwinBackbone"), + ("swinv2", "Swinv2Backbone"), + ("textnet", "TextNetBackbone"), + ("timm_backbone", "TimmBackbone"), + ("vitdet", "VitDetBackbone"), + ("vitpose_backbone", "VitPoseBackbone"), + ] +) + +MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( + [ + ("edgetam", "EdgeTamModel"), + ("edgetam_video", "EdgeTamModel"), + ("sam", "SamModel"), + ("sam2", "Sam2Model"), + ("sam2_video", "Sam2Model"), + ("sam_hq", "SamHQModel"), + ] +) + + +MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict( + [ + ("superpoint", "SuperPointForKeypointDetection"), + ] +) + +MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES = OrderedDict( + [ + ("efficientloftr", "EfficientLoFTRForKeypointMatching"), + ("lightglue", "LightGlueForKeypointMatching"), + ("superglue", "SuperGlueForKeypointMatching"), + ] +) + +MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( + [ + ("albert", "AlbertModel"), + ("bert", "BertModel"), + ("big_bird", "BigBirdModel"), + ("clip_text_model", "CLIPTextModel"), + ("data2vec-text", "Data2VecTextModel"), + ("deberta", "DebertaModel"), + ("deberta-v2", "DebertaV2Model"), + ("distilbert", "DistilBertModel"), + ("electra", "ElectraModel"), + ("emu3", "Emu3TextModel"), + ("flaubert", "FlaubertModel"), + ("ibert", "IBertModel"), + ("llama4", "Llama4TextModel"), + ("longformer", "LongformerModel"), + ("mllama", "MllamaTextModel"), + ("mobilebert", "MobileBertModel"), + ("mt5", "MT5EncoderModel"), + ("nystromformer", "NystromformerModel"), + ("reformer", "ReformerModel"), + ("rembert", "RemBertModel"), + ("roberta", "RobertaModel"), + ("roberta-prelayernorm", "RobertaPreLayerNormModel"), + ("roc_bert", "RoCBertModel"), + ("roformer", "RoFormerModel"), + ("squeezebert", "SqueezeBertModel"), + ("t5", "T5EncoderModel"), + ("t5gemma", "T5GemmaEncoderModel"), + ("umt5", "UMT5EncoderModel"), + ("xlm", "XLMModel"), + ("xlm-roberta", "XLMRobertaModel"), + ("xlm-roberta-xl", "XLMRobertaXLModel"), + ] +) + +MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"), + ("patchtst", "PatchTSTForClassification"), + ] +) + +MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict( + [ + ("patchtsmixer", "PatchTSMixerForRegression"), + ("patchtst", "PatchTSTForRegression"), + ] +) + +MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("timesfm", "TimesFmModelForPrediction"), + ] +) + +MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict( + [ + ("swin2sr", "Swin2SRForImageSuperResolution"), + ] +) + +MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict( + [ + ("dac", "DacModel"), + ] +) + +MODEL_MAPPING = 
_LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) +MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) +MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) +MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES +) +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES +) +MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES +) +MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES +) +MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES +) +MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) +MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES +) +MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES) +MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES +) +MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES +) +MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) +MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES) +MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES +) +MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) +MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES +) +MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES) +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) +MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES +) +MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES +) +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, 
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES +) +MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) +MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) +MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) + +MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES +) + +MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES) + +MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES) + +MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) + +MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES +) + +MODEL_FOR_KEYPOINT_MATCHING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES) + +MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) + +MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES +) + +MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES +) + +MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES +) + +MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES) + +MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES) + + +class AutoModelForMaskGeneration(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING + + +class AutoModelForKeypointDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING + + +class AutoModelForKeypointMatching(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_KEYPOINT_MATCHING_MAPPING + + +class AutoModelForTextEncoding(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING + + +class AutoModelForImageToImage(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING + + +class AutoModel(_BaseAutoModelClass): + _model_mapping = MODEL_MAPPING + + +AutoModel = auto_class_update(AutoModel) + + +class AutoModelForPreTraining(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_PRETRAINING_MAPPING + + +AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") + + +# Private on purpose, the public class will add the deprecation warnings. 
+class _AutoModelWithLMHead(_BaseAutoModelClass): + _model_mapping = MODEL_WITH_LM_HEAD_MAPPING + + +_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") + + +class AutoModelForCausalLM(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING + + # override to give better return typehint + @classmethod + def from_pretrained( + cls: type["AutoModelForCausalLM"], + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + *model_args, + **kwargs, + ) -> "_BaseModelWithGenerate": + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") + + +class AutoModelForMaskedLM(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASKED_LM_MAPPING + + +AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") + + +class AutoModelForSeq2SeqLM(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + + +AutoModelForSeq2SeqLM = auto_class_update( + AutoModelForSeq2SeqLM, + head_doc="sequence-to-sequence language modeling", + checkpoint_for_example="google-t5/t5-base", +) + + +class AutoModelForSequenceClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING + + +AutoModelForSequenceClassification = auto_class_update( + AutoModelForSequenceClassification, head_doc="sequence classification" +) + + +class AutoModelForQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING + + +AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") + + +class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING + + +AutoModelForTableQuestionAnswering = auto_class_update( + AutoModelForTableQuestionAnswering, + head_doc="table question answering", + checkpoint_for_example="google/tapas-base-finetuned-wtq", +) + + +class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING + + +AutoModelForVisualQuestionAnswering = auto_class_update( + AutoModelForVisualQuestionAnswering, + head_doc="visual question answering", + checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa", +) + + +class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + + +AutoModelForDocumentQuestionAnswering = auto_class_update( + AutoModelForDocumentQuestionAnswering, + head_doc="document question answering", + checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', +) + + +class AutoModelForTokenClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING + + +AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") + + +class AutoModelForMultipleChoice(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING + + +AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") + + +class AutoModelForNextSentencePrediction(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING + + +AutoModelForNextSentencePrediction = auto_class_update( + AutoModelForNextSentencePrediction, head_doc="next sentence prediction" +) + + +class 
AutoModelForImageClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + + +AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") + + +class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + + +AutoModelForZeroShotImageClassification = auto_class_update( + AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" +) + + +class AutoModelForImageSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING + + +AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation") + + +class AutoModelForSemanticSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + + +AutoModelForSemanticSegmentation = auto_class_update( + AutoModelForSemanticSegmentation, head_doc="semantic segmentation" +) + + +class AutoModelForTimeSeriesPrediction(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING + + +AutoModelForTimeSeriesPrediction = auto_class_update( + AutoModelForTimeSeriesPrediction, head_doc="time-series prediction" +) + + +class AutoModelForUniversalSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING + + +AutoModelForUniversalSegmentation = auto_class_update( + AutoModelForUniversalSegmentation, head_doc="universal image segmentation" +) + + +class AutoModelForInstanceSegmentation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING + + +AutoModelForInstanceSegmentation = auto_class_update( + AutoModelForInstanceSegmentation, head_doc="instance segmentation" +) + + +class AutoModelForObjectDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING + + +AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") + + +class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING + + +AutoModelForZeroShotObjectDetection = auto_class_update( + AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection" +) + + +class AutoModelForDepthEstimation(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING + + +AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation") + + +class AutoModelForVideoClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING + + +AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification") + + +# Private on purpose, the public class will add the deprecation warnings. 
+class _AutoModelForVision2Seq(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING + + +_AutoModelForVision2Seq = auto_class_update(_AutoModelForVision2Seq, head_doc="vision-to-text modeling") + + +class AutoModelForImageTextToText(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING + + # override to give better return typehint + @classmethod + def from_pretrained( + cls: type["AutoModelForImageTextToText"], + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + *model_args, + **kwargs, + ) -> "_BaseModelWithGenerate": + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling") + + +class AutoModelForAudioClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING + + +AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") + + +class AutoModelForCTC(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_CTC_MAPPING + + +AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") + + +class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING + + +AutoModelForSpeechSeq2Seq = auto_class_update( + AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" +) + + +class AutoModelForAudioFrameClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING + + +AutoModelForAudioFrameClassification = auto_class_update( + AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification" +) + + +class AutoModelForAudioXVector(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING + + +class AutoModelForTextToSpectrogram(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING + + +class AutoModelForTextToWaveform(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING + + +class AutoBackbone(_BaseAutoBackboneClass): + _model_mapping = MODEL_FOR_BACKBONE_MAPPING + + +AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") + + +class AutoModelForMaskedImageModeling(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING + + +AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling") + + +class AutoModelForAudioTokenization(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING + + +AutoModelForAudioTokenization = auto_class_update( + AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks" +) + + +class AutoModelWithLMHead(_AutoModelWithLMHead): + @classmethod + def from_config(cls, config, **kwargs): + warnings.warn( + "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " + "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " + "`AutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) + return super().from_config(config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. 
Please use " + "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " + "`AutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +class AutoModelForVision2Seq(_AutoModelForVision2Seq): + @classmethod + def from_config(cls, config, **kwargs): + warnings.warn( + "The class `AutoModelForVision2Seq` is deprecated and will be removed in v5.0. Please use " + "`AutoModelForImageTextToText` instead.", + FutureWarning, + ) + return super().from_config(config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "The class `AutoModelForVision2Seq` is deprecated and will be removed in v5.0. Please use " + "`AutoModelForImageTextToText` instead.", + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +__all__ = [ + "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_TOKENIZATION_MAPPING", + "MODEL_FOR_AUDIO_XVECTOR_MAPPING", + "MODEL_FOR_BACKBONE_MAPPING", + "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", + "MODEL_FOR_CAUSAL_LM_MAPPING", + "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_IMAGE_MAPPING", + "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", + "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", + "MODEL_FOR_KEYPOINT_DETECTION_MAPPING", + "MODEL_FOR_KEYPOINT_MATCHING_MAPPING", + "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", + "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", + "MODEL_FOR_MASKED_LM_MAPPING", + "MODEL_FOR_MASK_GENERATION_MAPPING", + "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "MODEL_FOR_OBJECT_DETECTION_MAPPING", + "MODEL_FOR_PRETRAINING_MAPPING", + "MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", + "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TEXT_ENCODING_MAPPING", + "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING", + "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING", + "MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING", + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", + "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", + "MODEL_FOR_VISION_2_SEQ_MAPPING", + "MODEL_FOR_RETRIEVAL_MAPPING", + "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", + "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", + "MODEL_MAPPING", + "MODEL_WITH_LM_HEAD_MAPPING", + "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", + "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING", + "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING", + "AutoModel", + "AutoBackbone", + "AutoModelForAudioClassification", + "AutoModelForAudioFrameClassification", + "AutoModelForAudioTokenization", + "AutoModelForAudioXVector", + "AutoModelForCausalLM", + "AutoModelForCTC", + "AutoModelForDepthEstimation", + "AutoModelForImageClassification", + "AutoModelForImageSegmentation", + "AutoModelForImageToImage", + "AutoModelForInstanceSegmentation", + "AutoModelForKeypointDetection", + "AutoModelForKeypointMatching", + "AutoModelForMaskGeneration", + "AutoModelForTextEncoding", + "AutoModelForMaskedImageModeling", + "AutoModelForMaskedLM", + 
"AutoModelForMultipleChoice", + "AutoModelForNextSentencePrediction", + "AutoModelForObjectDetection", + "AutoModelForPreTraining", + "AutoModelForQuestionAnswering", + "AutoModelForSemanticSegmentation", + "AutoModelForSeq2SeqLM", + "AutoModelForSequenceClassification", + "AutoModelForSpeechSeq2Seq", + "AutoModelForTableQuestionAnswering", + "AutoModelForTextToSpectrogram", + "AutoModelForTextToWaveform", + "AutoModelForTimeSeriesPrediction", + "AutoModelForTokenClassification", + "AutoModelForUniversalSegmentation", + "AutoModelForVideoClassification", + "AutoModelForVision2Seq", + "AutoModelForVisualQuestionAnswering", + "AutoModelForDocumentQuestionAnswering", + "AutoModelWithLMHead", + "AutoModelForZeroShotImageClassification", + "AutoModelForZeroShotObjectDetection", + "AutoModelForImageTextToText", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..0588d03cb6cdb43b94cc3fcd73b1791d1a5ee809 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py @@ -0,0 +1,413 @@ +# coding=utf-8 +# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Auto Model class.""" + +from collections import OrderedDict + +from ...utils import logging +from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .configuration_auto import CONFIG_MAPPING_NAMES + + +logger = logging.get_logger(__name__) + + +FLAX_MODEL_MAPPING_NAMES = OrderedDict( + [ + # Base model mapping + ("albert", "FlaxAlbertModel"), + ("bart", "FlaxBartModel"), + ("beit", "FlaxBeitModel"), + ("bert", "FlaxBertModel"), + ("big_bird", "FlaxBigBirdModel"), + ("blenderbot", "FlaxBlenderbotModel"), + ("blenderbot-small", "FlaxBlenderbotSmallModel"), + ("bloom", "FlaxBloomModel"), + ("clip", "FlaxCLIPModel"), + ("dinov2", "FlaxDinov2Model"), + ("distilbert", "FlaxDistilBertModel"), + ("electra", "FlaxElectraModel"), + ("gemma", "FlaxGemmaModel"), + ("gpt-sw3", "FlaxGPT2Model"), + ("gpt2", "FlaxGPT2Model"), + ("gpt_neo", "FlaxGPTNeoModel"), + ("gptj", "FlaxGPTJModel"), + ("llama", "FlaxLlamaModel"), + ("longt5", "FlaxLongT5Model"), + ("marian", "FlaxMarianModel"), + ("mbart", "FlaxMBartModel"), + ("mistral", "FlaxMistralModel"), + ("mt5", "FlaxMT5Model"), + ("opt", "FlaxOPTModel"), + ("pegasus", "FlaxPegasusModel"), + ("regnet", "FlaxRegNetModel"), + ("resnet", "FlaxResNetModel"), + ("roberta", "FlaxRobertaModel"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), + ("roformer", "FlaxRoFormerModel"), + ("t5", "FlaxT5Model"), + ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), + ("vit", "FlaxViTModel"), + ("wav2vec2", "FlaxWav2Vec2Model"), + ("whisper", "FlaxWhisperModel"), + ("xglm", "FlaxXGLMModel"), + ("xlm-roberta", "FlaxXLMRobertaModel"), + ] +) + +FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( + [ + # Model for pre-training mapping + ("albert", "FlaxAlbertForPreTraining"), + ("bart", "FlaxBartForConditionalGeneration"), + ("bert", "FlaxBertForPreTraining"), + ("big_bird", "FlaxBigBirdForPreTraining"), + ("electra", "FlaxElectraForPreTraining"), + ("longt5", "FlaxLongT5ForConditionalGeneration"), + ("mbart", "FlaxMBartForConditionalGeneration"), + ("mt5", "FlaxMT5ForConditionalGeneration"), + ("roberta", "FlaxRobertaForMaskedLM"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"), + ("roformer", "FlaxRoFormerForMaskedLM"), + ("t5", "FlaxT5ForConditionalGeneration"), + ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), + ("whisper", "FlaxWhisperForConditionalGeneration"), + ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), + ] +) + +FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Masked LM mapping + ("albert", "FlaxAlbertForMaskedLM"), + ("bart", "FlaxBartForConditionalGeneration"), + ("bert", "FlaxBertForMaskedLM"), + ("big_bird", "FlaxBigBirdForMaskedLM"), + ("distilbert", "FlaxDistilBertForMaskedLM"), + ("electra", "FlaxElectraForMaskedLM"), + ("mbart", "FlaxMBartForConditionalGeneration"), + ("roberta", "FlaxRobertaForMaskedLM"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"), + ("roformer", "FlaxRoFormerForMaskedLM"), + ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), + ] +) + +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Seq2Seq Causal LM mapping + ("bart", "FlaxBartForConditionalGeneration"), + ("blenderbot", "FlaxBlenderbotForConditionalGeneration"), + ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"), + ("encoder-decoder", "FlaxEncoderDecoderModel"), + ("longt5", "FlaxLongT5ForConditionalGeneration"), + ("marian", "FlaxMarianMTModel"), + ("mbart", "FlaxMBartForConditionalGeneration"), + ("mt5", 
"FlaxMT5ForConditionalGeneration"), + ("pegasus", "FlaxPegasusForConditionalGeneration"), + ("t5", "FlaxT5ForConditionalGeneration"), + ] +) + +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Image-classification + ("beit", "FlaxBeitForImageClassification"), + ("dinov2", "FlaxDinov2ForImageClassification"), + ("regnet", "FlaxRegNetForImageClassification"), + ("resnet", "FlaxResNetForImageClassification"), + ("vit", "FlaxViTForImageClassification"), + ] +) + +FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"), + ] +) + +FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Causal LM mapping + ("bart", "FlaxBartForCausalLM"), + ("bert", "FlaxBertForCausalLM"), + ("big_bird", "FlaxBigBirdForCausalLM"), + ("bloom", "FlaxBloomForCausalLM"), + ("electra", "FlaxElectraForCausalLM"), + ("gemma", "FlaxGemmaForCausalLM"), + ("gpt-sw3", "FlaxGPT2LMHeadModel"), + ("gpt2", "FlaxGPT2LMHeadModel"), + ("gpt_neo", "FlaxGPTNeoForCausalLM"), + ("gptj", "FlaxGPTJForCausalLM"), + ("llama", "FlaxLlamaForCausalLM"), + ("mistral", "FlaxMistralForCausalLM"), + ("opt", "FlaxOPTForCausalLM"), + ("roberta", "FlaxRobertaForCausalLM"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"), + ("xglm", "FlaxXGLMForCausalLM"), + ("xlm-roberta", "FlaxXLMRobertaForCausalLM"), + ] +) + +FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Sequence Classification mapping + ("albert", "FlaxAlbertForSequenceClassification"), + ("bart", "FlaxBartForSequenceClassification"), + ("bert", "FlaxBertForSequenceClassification"), + ("big_bird", "FlaxBigBirdForSequenceClassification"), + ("distilbert", "FlaxDistilBertForSequenceClassification"), + ("electra", "FlaxElectraForSequenceClassification"), + ("mbart", "FlaxMBartForSequenceClassification"), + ("roberta", "FlaxRobertaForSequenceClassification"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"), + ("roformer", "FlaxRoFormerForSequenceClassification"), + ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"), + ] +) + +FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Question Answering mapping + ("albert", "FlaxAlbertForQuestionAnswering"), + ("bart", "FlaxBartForQuestionAnswering"), + ("bert", "FlaxBertForQuestionAnswering"), + ("big_bird", "FlaxBigBirdForQuestionAnswering"), + ("distilbert", "FlaxDistilBertForQuestionAnswering"), + ("electra", "FlaxElectraForQuestionAnswering"), + ("mbart", "FlaxMBartForQuestionAnswering"), + ("roberta", "FlaxRobertaForQuestionAnswering"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"), + ("roformer", "FlaxRoFormerForQuestionAnswering"), + ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"), + ] +) + +FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Token Classification mapping + ("albert", "FlaxAlbertForTokenClassification"), + ("bert", "FlaxBertForTokenClassification"), + ("big_bird", "FlaxBigBirdForTokenClassification"), + ("distilbert", "FlaxDistilBertForTokenClassification"), + ("electra", "FlaxElectraForTokenClassification"), + ("roberta", "FlaxRobertaForTokenClassification"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"), + ("roformer", "FlaxRoFormerForTokenClassification"), + ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"), + ] +) + +FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( + [ + # Model 
for Multiple Choice mapping + ("albert", "FlaxAlbertForMultipleChoice"), + ("bert", "FlaxBertForMultipleChoice"), + ("big_bird", "FlaxBigBirdForMultipleChoice"), + ("distilbert", "FlaxDistilBertForMultipleChoice"), + ("electra", "FlaxElectraForMultipleChoice"), + ("roberta", "FlaxRobertaForMultipleChoice"), + ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"), + ("roformer", "FlaxRoFormerForMultipleChoice"), + ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"), + ] +) + +FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("bert", "FlaxBertForNextSentencePrediction"), + ] +) + +FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"), + ("whisper", "FlaxWhisperForConditionalGeneration"), + ] +) + +FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("whisper", "FlaxWhisperForAudioClassification"), + ] +) + +FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) +FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) +FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES) +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) +FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +) +FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES +) +FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES +) +FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES +) +FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES +) +FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES +) +FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES +) + + +class FlaxAutoModel(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_MAPPING + + +FlaxAutoModel = auto_class_update(FlaxAutoModel) + + +class FlaxAutoModelForPreTraining(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING + + +FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining") + + +class FlaxAutoModelForCausalLM(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING + + +FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling") + + +class FlaxAutoModelForMaskedLM(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING + + +FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, 
head_doc="masked language modeling") + + +class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + + +FlaxAutoModelForSeq2SeqLM = auto_class_update( + FlaxAutoModelForSeq2SeqLM, + head_doc="sequence-to-sequence language modeling", + checkpoint_for_example="google-t5/t5-base", +) + + +class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING + + +FlaxAutoModelForSequenceClassification = auto_class_update( + FlaxAutoModelForSequenceClassification, head_doc="sequence classification" +) + + +class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING + + +FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering") + + +class FlaxAutoModelForTokenClassification(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING + + +FlaxAutoModelForTokenClassification = auto_class_update( + FlaxAutoModelForTokenClassification, head_doc="token classification" +) + + +class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING + + +FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice") + + +class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING + + +FlaxAutoModelForNextSentencePrediction = auto_class_update( + FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction" +) + + +class FlaxAutoModelForImageClassification(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + + +FlaxAutoModelForImageClassification = auto_class_update( + FlaxAutoModelForImageClassification, head_doc="image classification" +) + + +class FlaxAutoModelForVision2Seq(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING + + +FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling") + + +class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass): + _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING + + +FlaxAutoModelForSpeechSeq2Seq = auto_class_update( + FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" +) + +__all__ = [ + "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_MASKED_LM_MAPPING", + "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "FLAX_MODEL_FOR_PRETRAINING_MAPPING", + "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", + "FLAX_MODEL_MAPPING", + "FlaxAutoModel", + "FlaxAutoModelForCausalLM", + "FlaxAutoModelForImageClassification", + "FlaxAutoModelForMaskedLM", + "FlaxAutoModelForMultipleChoice", + "FlaxAutoModelForNextSentencePrediction", + "FlaxAutoModelForPreTraining", + "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSeq2SeqLM", + "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForSpeechSeq2Seq", + "FlaxAutoModelForTokenClassification", + "FlaxAutoModelForVision2Seq", +] diff --git 
a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..cf39f4d7c9c40bd87a8e4c5e3037e2cbe3574a29 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py @@ -0,0 +1,776 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Auto Model class.""" + +import warnings +from collections import OrderedDict + +from ...utils import logging +from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .configuration_auto import CONFIG_MAPPING_NAMES + + +logger = logging.get_logger(__name__) + + +TF_MODEL_MAPPING_NAMES = OrderedDict( + [ + # Base model mapping + ("albert", "TFAlbertModel"), + ("bart", "TFBartModel"), + ("bert", "TFBertModel"), + ("blenderbot", "TFBlenderbotModel"), + ("blenderbot-small", "TFBlenderbotSmallModel"), + ("blip", "TFBlipModel"), + ("camembert", "TFCamembertModel"), + ("clip", "TFCLIPModel"), + ("convbert", "TFConvBertModel"), + ("convnext", "TFConvNextModel"), + ("convnextv2", "TFConvNextV2Model"), + ("ctrl", "TFCTRLModel"), + ("cvt", "TFCvtModel"), + ("data2vec-vision", "TFData2VecVisionModel"), + ("deberta", "TFDebertaModel"), + ("deberta-v2", "TFDebertaV2Model"), + ("deit", "TFDeiTModel"), + ("distilbert", "TFDistilBertModel"), + ("dpr", "TFDPRQuestionEncoder"), + ("efficientformer", "TFEfficientFormerModel"), + ("electra", "TFElectraModel"), + ("esm", "TFEsmModel"), + ("flaubert", "TFFlaubertModel"), + ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")), + ("gpt-sw3", "TFGPT2Model"), + ("gpt2", "TFGPT2Model"), + ("gptj", "TFGPTJModel"), + ("groupvit", "TFGroupViTModel"), + ("hubert", "TFHubertModel"), + ("idefics", "TFIdeficsModel"), + ("layoutlm", "TFLayoutLMModel"), + ("layoutlmv3", "TFLayoutLMv3Model"), + ("led", "TFLEDModel"), + ("longformer", "TFLongformerModel"), + ("lxmert", "TFLxmertModel"), + ("marian", "TFMarianModel"), + ("mbart", "TFMBartModel"), + ("mistral", "TFMistralModel"), + ("mobilebert", "TFMobileBertModel"), + ("mobilevit", "TFMobileViTModel"), + ("mpnet", "TFMPNetModel"), + ("mt5", "TFMT5Model"), + ("openai-gpt", "TFOpenAIGPTModel"), + ("opt", "TFOPTModel"), + ("pegasus", "TFPegasusModel"), + ("regnet", "TFRegNetModel"), + ("rembert", "TFRemBertModel"), + ("resnet", "TFResNetModel"), + ("roberta", "TFRobertaModel"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), + ("roformer", "TFRoFormerModel"), + ("sam", "TFSamModel"), + ("sam_vision_model", "TFSamVisionModel"), + ("segformer", "TFSegformerModel"), + ("speech_to_text", "TFSpeech2TextModel"), + ("swiftformer", "TFSwiftFormerModel"), + ("swin", "TFSwinModel"), + ("t5", "TFT5Model"), + ("tapas", "TFTapasModel"), + ("transfo-xl", "TFTransfoXLModel"), + ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"), + ("vit", "TFViTModel"), + ("vit_mae", 
"TFViTMAEModel"), + ("wav2vec2", "TFWav2Vec2Model"), + ("whisper", "TFWhisperModel"), + ("xglm", "TFXGLMModel"), + ("xlm", "TFXLMModel"), + ("xlm-roberta", "TFXLMRobertaModel"), + ("xlnet", "TFXLNetModel"), + ] +) + +TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( + [ + # Model for pre-training mapping + ("albert", "TFAlbertForPreTraining"), + ("bart", "TFBartForConditionalGeneration"), + ("bert", "TFBertForPreTraining"), + ("camembert", "TFCamembertForMaskedLM"), + ("ctrl", "TFCTRLLMHeadModel"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("electra", "TFElectraForPreTraining"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("funnel", "TFFunnelForPreTraining"), + ("gpt-sw3", "TFGPT2LMHeadModel"), + ("gpt2", "TFGPT2LMHeadModel"), + ("idefics", "TFIdeficsForVisionText2Text"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("lxmert", "TFLxmertForPreTraining"), + ("mobilebert", "TFMobileBertForPreTraining"), + ("mpnet", "TFMPNetForMaskedLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("roberta", "TFRobertaForMaskedLM"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"), + ("t5", "TFT5ForConditionalGeneration"), + ("tapas", "TFTapasForMaskedLM"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), + ("vit_mae", "TFViTMAEForPreTraining"), + ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("xlnet", "TFXLNetLMHeadModel"), + ] +) + +TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( + [ + # Model with LM heads mapping + ("albert", "TFAlbertForMaskedLM"), + ("bart", "TFBartForConditionalGeneration"), + ("bert", "TFBertForMaskedLM"), + ("camembert", "TFCamembertForMaskedLM"), + ("convbert", "TFConvBertForMaskedLM"), + ("ctrl", "TFCTRLLMHeadModel"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("electra", "TFElectraForMaskedLM"), + ("esm", "TFEsmForMaskedLM"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("funnel", "TFFunnelForMaskedLM"), + ("gpt-sw3", "TFGPT2LMHeadModel"), + ("gpt2", "TFGPT2LMHeadModel"), + ("gptj", "TFGPTJForCausalLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("led", "TFLEDForConditionalGeneration"), + ("longformer", "TFLongformerForMaskedLM"), + ("marian", "TFMarianMTModel"), + ("mobilebert", "TFMobileBertForMaskedLM"), + ("mpnet", "TFMPNetForMaskedLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("rembert", "TFRemBertForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"), + ("roformer", "TFRoFormerForMaskedLM"), + ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), + ("t5", "TFT5ForConditionalGeneration"), + ("tapas", "TFTapasForMaskedLM"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), + ("whisper", "TFWhisperForConditionalGeneration"), + ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("xlnet", "TFXLNetLMHeadModel"), + ] +) + +TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Causal LM mapping + ("bert", "TFBertLMHeadModel"), + ("camembert", "TFCamembertForCausalLM"), + ("ctrl", "TFCTRLLMHeadModel"), + ("gpt-sw3", "TFGPT2LMHeadModel"), + ("gpt2", "TFGPT2LMHeadModel"), + ("gptj", "TFGPTJForCausalLM"), + ("mistral", "TFMistralForCausalLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("opt", "TFOPTForCausalLM"), + ("rembert", "TFRemBertForCausalLM"), + ("roberta", "TFRobertaForCausalLM"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"), + ("roformer", "TFRoFormerForCausalLM"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), + ("xglm", "TFXGLMForCausalLM"), + ("xlm", "TFXLMWithLMHeadModel"), + 
("xlm-roberta", "TFXLMRobertaForCausalLM"), + ("xlnet", "TFXLNetLMHeadModel"), + ] +) + +TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( + [ + ("deit", "TFDeiTForMaskedImageModeling"), + ("swin", "TFSwinForMaskedImageModeling"), + ] +) + +TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Image-classsification + ("convnext", "TFConvNextForImageClassification"), + ("convnextv2", "TFConvNextV2ForImageClassification"), + ("cvt", "TFCvtForImageClassification"), + ("data2vec-vision", "TFData2VecVisionForImageClassification"), + ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")), + ( + "efficientformer", + ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"), + ), + ("mobilevit", "TFMobileViTForImageClassification"), + ("regnet", "TFRegNetForImageClassification"), + ("resnet", "TFResNetForImageClassification"), + ("segformer", "TFSegformerForImageClassification"), + ("swiftformer", "TFSwiftFormerForImageClassification"), + ("swin", "TFSwinForImageClassification"), + ("vit", "TFViTForImageClassification"), + ] +) + + +TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Zero Shot Image Classification mapping + ("blip", "TFBlipModel"), + ("clip", "TFCLIPModel"), + ] +) + + +TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Semantic Segmentation mapping + ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"), + ("mobilevit", "TFMobileViTForSemanticSegmentation"), + ("segformer", "TFSegformerForSemanticSegmentation"), + ] +) + +TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("blip", "TFBlipForConditionalGeneration"), + ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"), + ] +) + +TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Masked LM mapping + ("albert", "TFAlbertForMaskedLM"), + ("bert", "TFBertForMaskedLM"), + ("camembert", "TFCamembertForMaskedLM"), + ("convbert", "TFConvBertForMaskedLM"), + ("deberta", "TFDebertaForMaskedLM"), + ("deberta-v2", "TFDebertaV2ForMaskedLM"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("electra", "TFElectraForMaskedLM"), + ("esm", "TFEsmForMaskedLM"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("funnel", "TFFunnelForMaskedLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("longformer", "TFLongformerForMaskedLM"), + ("mobilebert", "TFMobileBertForMaskedLM"), + ("mpnet", "TFMPNetForMaskedLM"), + ("rembert", "TFRemBertForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"), + ("roformer", "TFRoFormerForMaskedLM"), + ("tapas", "TFTapasForMaskedLM"), + ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ] +) + +TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( + [ + # Model for Seq2Seq Causal LM mapping + ("bart", "TFBartForConditionalGeneration"), + ("blenderbot", "TFBlenderbotForConditionalGeneration"), + ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"), + ("encoder-decoder", "TFEncoderDecoderModel"), + ("led", "TFLEDForConditionalGeneration"), + ("marian", "TFMarianMTModel"), + ("mbart", "TFMBartForConditionalGeneration"), + ("mt5", "TFMT5ForConditionalGeneration"), + ("pegasus", "TFPegasusForConditionalGeneration"), + ("t5", "TFT5ForConditionalGeneration"), + ] +) + +TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), + 
("whisper", "TFWhisperForConditionalGeneration"), + ] +) + +TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Sequence Classification mapping + ("albert", "TFAlbertForSequenceClassification"), + ("bart", "TFBartForSequenceClassification"), + ("bert", "TFBertForSequenceClassification"), + ("camembert", "TFCamembertForSequenceClassification"), + ("convbert", "TFConvBertForSequenceClassification"), + ("ctrl", "TFCTRLForSequenceClassification"), + ("deberta", "TFDebertaForSequenceClassification"), + ("deberta-v2", "TFDebertaV2ForSequenceClassification"), + ("distilbert", "TFDistilBertForSequenceClassification"), + ("electra", "TFElectraForSequenceClassification"), + ("esm", "TFEsmForSequenceClassification"), + ("flaubert", "TFFlaubertForSequenceClassification"), + ("funnel", "TFFunnelForSequenceClassification"), + ("gpt-sw3", "TFGPT2ForSequenceClassification"), + ("gpt2", "TFGPT2ForSequenceClassification"), + ("gptj", "TFGPTJForSequenceClassification"), + ("layoutlm", "TFLayoutLMForSequenceClassification"), + ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"), + ("longformer", "TFLongformerForSequenceClassification"), + ("mistral", "TFMistralForSequenceClassification"), + ("mobilebert", "TFMobileBertForSequenceClassification"), + ("mpnet", "TFMPNetForSequenceClassification"), + ("openai-gpt", "TFOpenAIGPTForSequenceClassification"), + ("rembert", "TFRemBertForSequenceClassification"), + ("roberta", "TFRobertaForSequenceClassification"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"), + ("roformer", "TFRoFormerForSequenceClassification"), + ("tapas", "TFTapasForSequenceClassification"), + ("transfo-xl", "TFTransfoXLForSequenceClassification"), + ("xlm", "TFXLMForSequenceClassification"), + ("xlm-roberta", "TFXLMRobertaForSequenceClassification"), + ("xlnet", "TFXLNetForSequenceClassification"), + ] +) + +TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Question Answering mapping + ("albert", "TFAlbertForQuestionAnswering"), + ("bert", "TFBertForQuestionAnswering"), + ("camembert", "TFCamembertForQuestionAnswering"), + ("convbert", "TFConvBertForQuestionAnswering"), + ("deberta", "TFDebertaForQuestionAnswering"), + ("deberta-v2", "TFDebertaV2ForQuestionAnswering"), + ("distilbert", "TFDistilBertForQuestionAnswering"), + ("electra", "TFElectraForQuestionAnswering"), + ("flaubert", "TFFlaubertForQuestionAnsweringSimple"), + ("funnel", "TFFunnelForQuestionAnswering"), + ("gptj", "TFGPTJForQuestionAnswering"), + ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"), + ("longformer", "TFLongformerForQuestionAnswering"), + ("mobilebert", "TFMobileBertForQuestionAnswering"), + ("mpnet", "TFMPNetForQuestionAnswering"), + ("rembert", "TFRemBertForQuestionAnswering"), + ("roberta", "TFRobertaForQuestionAnswering"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"), + ("roformer", "TFRoFormerForQuestionAnswering"), + ("xlm", "TFXLMForQuestionAnsweringSimple"), + ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"), + ("xlnet", "TFXLNetForQuestionAnsweringSimple"), + ] +) +TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")]) + +TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + ("layoutlm", "TFLayoutLMForQuestionAnswering"), + ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"), + ] +) + + +TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + # Model for Table Question Answering mapping 
+ ("tapas", "TFTapasForQuestionAnswering"), + ] +) + +TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Token Classification mapping + ("albert", "TFAlbertForTokenClassification"), + ("bert", "TFBertForTokenClassification"), + ("camembert", "TFCamembertForTokenClassification"), + ("convbert", "TFConvBertForTokenClassification"), + ("deberta", "TFDebertaForTokenClassification"), + ("deberta-v2", "TFDebertaV2ForTokenClassification"), + ("distilbert", "TFDistilBertForTokenClassification"), + ("electra", "TFElectraForTokenClassification"), + ("esm", "TFEsmForTokenClassification"), + ("flaubert", "TFFlaubertForTokenClassification"), + ("funnel", "TFFunnelForTokenClassification"), + ("layoutlm", "TFLayoutLMForTokenClassification"), + ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"), + ("longformer", "TFLongformerForTokenClassification"), + ("mobilebert", "TFMobileBertForTokenClassification"), + ("mpnet", "TFMPNetForTokenClassification"), + ("rembert", "TFRemBertForTokenClassification"), + ("roberta", "TFRobertaForTokenClassification"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"), + ("roformer", "TFRoFormerForTokenClassification"), + ("xlm", "TFXLMForTokenClassification"), + ("xlm-roberta", "TFXLMRobertaForTokenClassification"), + ("xlnet", "TFXLNetForTokenClassification"), + ] +) + +TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( + [ + # Model for Multiple Choice mapping + ("albert", "TFAlbertForMultipleChoice"), + ("bert", "TFBertForMultipleChoice"), + ("camembert", "TFCamembertForMultipleChoice"), + ("convbert", "TFConvBertForMultipleChoice"), + ("deberta-v2", "TFDebertaV2ForMultipleChoice"), + ("distilbert", "TFDistilBertForMultipleChoice"), + ("electra", "TFElectraForMultipleChoice"), + ("flaubert", "TFFlaubertForMultipleChoice"), + ("funnel", "TFFunnelForMultipleChoice"), + ("longformer", "TFLongformerForMultipleChoice"), + ("mobilebert", "TFMobileBertForMultipleChoice"), + ("mpnet", "TFMPNetForMultipleChoice"), + ("rembert", "TFRemBertForMultipleChoice"), + ("roberta", "TFRobertaForMultipleChoice"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"), + ("roformer", "TFRoFormerForMultipleChoice"), + ("xlm", "TFXLMForMultipleChoice"), + ("xlm-roberta", "TFXLMRobertaForMultipleChoice"), + ("xlnet", "TFXLNetForMultipleChoice"), + ] +) + +TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("bert", "TFBertForNextSentencePrediction"), + ("mobilebert", "TFMobileBertForNextSentencePrediction"), + ] +) +TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( + [ + ("sam", "TFSamModel"), + ] +) +TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( + [ + ("albert", "TFAlbertModel"), + ("bert", "TFBertModel"), + ("convbert", "TFConvBertModel"), + ("deberta", "TFDebertaModel"), + ("deberta-v2", "TFDebertaV2Model"), + ("distilbert", "TFDistilBertModel"), + ("electra", "TFElectraModel"), + ("flaubert", "TFFlaubertModel"), + ("longformer", "TFLongformerModel"), + ("mobilebert", "TFMobileBertModel"), + ("mt5", "TFMT5EncoderModel"), + ("rembert", "TFRemBertModel"), + ("roberta", "TFRobertaModel"), + ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"), + ("roformer", "TFRoFormerModel"), + ("t5", "TFT5EncoderModel"), + ("xlm", "TFXLMModel"), + ("xlm-roberta", "TFXLMRobertaModel"), + ] +) + +TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) +TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) 
+TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES) +TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES +) +TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES +) +TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) +TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES) +TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) +TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +) +TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES +) +TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES +) +TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES +) +TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES +) +TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES +) +TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES +) +TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES +) +TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES +) + +TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES +) + +TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) + + +class TFAutoModelForMaskGeneration(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING + + +class TFAutoModelForTextEncoding(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING + + +class TFAutoModel(_BaseAutoModelClass): + _model_mapping = TF_MODEL_MAPPING + + +TFAutoModel = auto_class_update(TFAutoModel) + + +class TFAutoModelForAudioClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING + + +TFAutoModelForAudioClassification = auto_class_update( + TFAutoModelForAudioClassification, head_doc="audio classification" +) + + +class TFAutoModelForPreTraining(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING + + +TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining") + + +# Private on purpose, the public class will 
add the deprecation warnings. +class _TFAutoModelWithLMHead(_BaseAutoModelClass): + _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING + + +_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling") + + +class TFAutoModelForCausalLM(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING + + +TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling") + + +class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING + + +TFAutoModelForMaskedImageModeling = auto_class_update( + TFAutoModelForMaskedImageModeling, head_doc="masked image modeling" +) + + +class TFAutoModelForImageClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + + +TFAutoModelForImageClassification = auto_class_update( + TFAutoModelForImageClassification, head_doc="image classification" +) + + +class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING + + +TFAutoModelForZeroShotImageClassification = auto_class_update( + TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" +) + + +class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING + + +TFAutoModelForSemanticSegmentation = auto_class_update( + TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation" +) + + +class TFAutoModelForVision2Seq(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING + + +TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling") + + +class TFAutoModelForMaskedLM(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING + + +TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling") + + +class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + + +TFAutoModelForSeq2SeqLM = auto_class_update( + TFAutoModelForSeq2SeqLM, + head_doc="sequence-to-sequence language modeling", + checkpoint_for_example="google-t5/t5-base", +) + + +class TFAutoModelForSequenceClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING + + +TFAutoModelForSequenceClassification = auto_class_update( + TFAutoModelForSequenceClassification, head_doc="sequence classification" +) + + +class TFAutoModelForQuestionAnswering(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING + + +TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering") + + +class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + + +TFAutoModelForDocumentQuestionAnswering = auto_class_update( + TFAutoModelForDocumentQuestionAnswering, + head_doc="document question answering", + checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', +) + + +class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING + + +TFAutoModelForTableQuestionAnswering = auto_class_update( + TFAutoModelForTableQuestionAnswering, + head_doc="table question answering", + checkpoint_for_example="google/tapas-base-finetuned-wtq", +) + + +class 
TFAutoModelForTokenClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING + + +TFAutoModelForTokenClassification = auto_class_update( + TFAutoModelForTokenClassification, head_doc="token classification" +) + + +class TFAutoModelForMultipleChoice(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING + + +TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice") + + +class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING + + +TFAutoModelForNextSentencePrediction = auto_class_update( + TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction" +) + + +class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING + + +TFAutoModelForSpeechSeq2Seq = auto_class_update( + TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" +) + + +class TFAutoModelWithLMHead(_TFAutoModelWithLMHead): + @classmethod + def from_config(cls, config): + warnings.warn( + "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use" + " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models" + " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) + return super().from_config(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use" + " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models" + " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.", + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + +__all__ = [ + "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "TF_MODEL_FOR_CAUSAL_LM_MAPPING", + "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "TF_MODEL_FOR_MASK_GENERATION_MAPPING", + "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", + "TF_MODEL_FOR_MASKED_LM_MAPPING", + "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "TF_MODEL_FOR_PRETRAINING_MAPPING", + "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", + "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_TEXT_ENCODING_MAPPING", + "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", + "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", + "TF_MODEL_MAPPING", + "TF_MODEL_WITH_LM_HEAD_MAPPING", + "TFAutoModel", + "TFAutoModelForAudioClassification", + "TFAutoModelForCausalLM", + "TFAutoModelForImageClassification", + "TFAutoModelForMaskedImageModeling", + "TFAutoModelForMaskedLM", + "TFAutoModelForMaskGeneration", + "TFAutoModelForMultipleChoice", + "TFAutoModelForNextSentencePrediction", + "TFAutoModelForPreTraining", + "TFAutoModelForDocumentQuestionAnswering", + "TFAutoModelForQuestionAnswering", + "TFAutoModelForSemanticSegmentation", + "TFAutoModelForSeq2SeqLM", + "TFAutoModelForSequenceClassification", + "TFAutoModelForSpeechSeq2Seq", + 
"TFAutoModelForTableQuestionAnswering", + "TFAutoModelForTextEncoding", + "TFAutoModelForTokenClassification", + "TFAutoModelForVision2Seq", + "TFAutoModelForZeroShotImageClassification", + "TFAutoModelWithLMHead", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..11862a5896b94be7d1d247f022026bf64088987d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py @@ -0,0 +1,443 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AutoProcessor class.""" + +import importlib +import inspect +import json +import warnings +from collections import OrderedDict + +# Build the list of all feature extractors +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...feature_extraction_utils import FeatureExtractionMixin +from ...image_processing_utils import ImageProcessingMixin +from ...processing_utils import ProcessorMixin +from ...tokenization_utils import TOKENIZER_CONFIG_FILE +from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging +from ...video_processing_utils import BaseVideoProcessor +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + model_type_to_module_name, + replace_list_option_in_docstrings, +) +from .feature_extraction_auto import AutoFeatureExtractor +from .image_processing_auto import AutoImageProcessor +from .tokenization_auto import AutoTokenizer + + +logger = logging.get_logger(__name__) + +PROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ("aimv2", "CLIPProcessor"), + ("align", "AlignProcessor"), + ("altclip", "AltCLIPProcessor"), + ("aria", "AriaProcessor"), + ("aya_vision", "AyaVisionProcessor"), + ("bark", "BarkProcessor"), + ("blip", "BlipProcessor"), + ("blip-2", "Blip2Processor"), + ("bridgetower", "BridgeTowerProcessor"), + ("chameleon", "ChameleonProcessor"), + ("chinese_clip", "ChineseCLIPProcessor"), + ("clap", "ClapProcessor"), + ("clip", "CLIPProcessor"), + ("clipseg", "CLIPSegProcessor"), + ("clvp", "ClvpProcessor"), + ("cohere2_vision", "Cohere2VisionProcessor"), + ("colpali", "ColPaliProcessor"), + ("colqwen2", "ColQwen2Processor"), + ("deepseek_vl", "DeepseekVLProcessor"), + ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"), + ("dia", "DiaProcessor"), + ("edgetam", "Sam2Processor"), + ("emu3", "Emu3Processor"), + ("evolla", "EvollaProcessor"), + ("flava", "FlavaProcessor"), + ("florence2", "Florence2Processor"), + ("fuyu", "FuyuProcessor"), + ("gemma3", "Gemma3Processor"), + ("gemma3n", "Gemma3nProcessor"), + ("git", "GitProcessor"), + ("glm4v", "Glm4vProcessor"), + ("glm4v_moe", "Glm4vProcessor"), + ("got_ocr2", "GotOcr2Processor"), + 
("granite_speech", "GraniteSpeechProcessor"), + ("grounding-dino", "GroundingDinoProcessor"), + ("groupvit", "CLIPProcessor"), + ("hubert", "Wav2Vec2Processor"), + ("idefics", "IdeficsProcessor"), + ("idefics2", "Idefics2Processor"), + ("idefics3", "Idefics3Processor"), + ("instructblip", "InstructBlipProcessor"), + ("instructblipvideo", "InstructBlipVideoProcessor"), + ("internvl", "InternVLProcessor"), + ("janus", "JanusProcessor"), + ("kosmos-2", "Kosmos2Processor"), + ("kosmos-2.5", "Kosmos2_5Processor"), + ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"), + ("layoutlmv2", "LayoutLMv2Processor"), + ("layoutlmv3", "LayoutLMv3Processor"), + ("lfm2_vl", "Lfm2VlProcessor"), + ("llama4", "Llama4Processor"), + ("llava", "LlavaProcessor"), + ("llava_next", "LlavaNextProcessor"), + ("llava_next_video", "LlavaNextVideoProcessor"), + ("llava_onevision", "LlavaOnevisionProcessor"), + ("markuplm", "MarkupLMProcessor"), + ("mctct", "MCTCTProcessor"), + ("metaclip_2", "CLIPProcessor"), + ("mgp-str", "MgpstrProcessor"), + ("mistral3", "PixtralProcessor"), + ("mllama", "MllamaProcessor"), + ("mm-grounding-dino", "GroundingDinoProcessor"), + ("moonshine", "Wav2Vec2Processor"), + ("oneformer", "OneFormerProcessor"), + ("ovis2", "Ovis2Processor"), + ("owlv2", "Owlv2Processor"), + ("owlvit", "OwlViTProcessor"), + ("paligemma", "PaliGemmaProcessor"), + ("perception_lm", "PerceptionLMProcessor"), + ("phi4_multimodal", "Phi4MultimodalProcessor"), + ("pix2struct", "Pix2StructProcessor"), + ("pixtral", "PixtralProcessor"), + ("pop2piano", "Pop2PianoProcessor"), + ("qwen2_5_omni", "Qwen2_5OmniProcessor"), + ("qwen2_5_vl", "Qwen2_5_VLProcessor"), + ("qwen2_audio", "Qwen2AudioProcessor"), + ("qwen2_vl", "Qwen2VLProcessor"), + ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"), + ("qwen3_vl", "Qwen3VLProcessor"), + ("qwen3_vl_moe", "Qwen3VLProcessor"), + ("sam", "SamProcessor"), + ("sam2", "Sam2Processor"), + ("sam_hq", "SamHQProcessor"), + ("seamless_m4t", "SeamlessM4TProcessor"), + ("sew", "Wav2Vec2Processor"), + ("sew-d", "Wav2Vec2Processor"), + ("shieldgemma2", "ShieldGemma2Processor"), + ("siglip", "SiglipProcessor"), + ("siglip2", "Siglip2Processor"), + ("smolvlm", "SmolVLMProcessor"), + ("speech_to_text", "Speech2TextProcessor"), + ("speech_to_text_2", "Speech2Text2Processor"), + ("speecht5", "SpeechT5Processor"), + ("trocr", "TrOCRProcessor"), + ("tvlt", "TvltProcessor"), + ("tvp", "TvpProcessor"), + ("udop", "UdopProcessor"), + ("unispeech", "Wav2Vec2Processor"), + ("unispeech-sat", "Wav2Vec2Processor"), + ("video_llava", "VideoLlavaProcessor"), + ("vilt", "ViltProcessor"), + ("vipllava", "LlavaProcessor"), + ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), + ("voxtral", "VoxtralProcessor"), + ("wav2vec2", "Wav2Vec2Processor"), + ("wav2vec2-bert", "Wav2Vec2Processor"), + ("wav2vec2-conformer", "Wav2Vec2Processor"), + ("wavlm", "Wav2Vec2Processor"), + ("whisper", "WhisperProcessor"), + ("xclip", "XCLIPProcessor"), + ] +) + +PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES) + + +def processor_class_from_name(class_name: str): + for module_name, processors in PROCESSOR_MAPPING_NAMES.items(): + if class_name in processors: + module_name = model_type_to_module_name(module_name) + + module = importlib.import_module(f".{module_name}", "transformers.models") + try: + return getattr(module, class_name) + except AttributeError: + continue + + for processor in PROCESSOR_MAPPING._extra_content.values(): + if getattr(processor, "__name__", None) == class_name: + return 
processor + + # We did not fine the class, but maybe it's because a dep is missing. In that case, the class will be in the main + # init and we return the proper dummy to get an appropriate error message. + main_module = importlib.import_module("transformers") + if hasattr(main_module, class_name): + return getattr(main_module, class_name) + + return None + + +class AutoProcessor: + r""" + This is a generic processor class that will be instantiated as one of the processor classes of the library when + created with the [`AutoProcessor.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self): + raise OSError( + "AutoProcessor is designed to be instantiated " + "using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + @replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES) + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" + Instantiate one of the processor classes of the library from a pretrained model vocabulary. + + The processor class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible): + + List options + + Params: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a processor files saved using the `save_pretrained()` method, + e.g., `./my_model_directory/`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model feature extractor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the feature extractor files and override the cached versions + if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final feature extractor object. If `True`, then this + functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of + `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. 
This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are feature extractor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + + + Passing `token=True` is required when you want to use a private model. + + + + Examples: + + ```python + >>> from transformers import AutoProcessor + + >>> # Download processor from huggingface.co and cache. + >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") + + >>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*) + >>> # processor = AutoProcessor.from_pretrained("./test/saved_model/") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + config = kwargs.pop("config", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs["_from_auto"] = True + + processor_class = None + processor_auto_map = None + + # First, let's see if we have a processor or preprocessor config. + # Filter the kwargs for `cached_file`. + cached_file_kwargs = {key: kwargs[key] for key in inspect.signature(cached_file).parameters if key in kwargs} + # We don't want to raise + cached_file_kwargs.update( + { + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_raise_exceptions_for_connection_errors": False, + } + ) + + # Let's start by checking whether the processor class is saved in a processor config + processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs) + if processor_config_file is not None: + config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs) + processor_class = config_dict.get("processor_class", None) + if "AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + + if processor_class is None: + # If not found, let's check whether the processor class is saved in an image processor config + preprocessor_config_file = cached_file( + pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs + ) + if preprocessor_config_file is not None: + config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) + processor_class = config_dict.get("processor_class", None) + if "AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + + # Saved as video processor + if preprocessor_config_file is None: + preprocessor_config_file = cached_file( + pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs + ) + if preprocessor_config_file is not None: + config_dict, _ = BaseVideoProcessor.get_video_processor_dict( + pretrained_model_name_or_path, **kwargs + ) + processor_class = config_dict.get("processor_class", None) + if 
"AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + + # Saved as feature extractor + if preprocessor_config_file is None: + preprocessor_config_file = cached_file( + pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs + ) + if preprocessor_config_file is not None and processor_class is None: + config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict( + pretrained_model_name_or_path, **kwargs + ) + processor_class = config_dict.get("processor_class", None) + if "AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + + if processor_class is None: + # Next, let's check whether the processor class is saved in a tokenizer + tokenizer_config_file = cached_file( + pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs + ) + if tokenizer_config_file is not None: + with open(tokenizer_config_file, encoding="utf-8") as reader: + config_dict = json.load(reader) + + processor_class = config_dict.get("processor_class", None) + if "AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + + if processor_class is None: + # Otherwise, load config, if it can be loaded. + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + + # And check if the config contains the processor class. + processor_class = getattr(config, "processor_class", None) + if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map: + processor_auto_map = config.auto_map["AutoProcessor"] + + if processor_class is not None: + processor_class = processor_class_from_name(processor_class) + + has_remote_code = processor_auto_map is not None + has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING + if has_remote_code: + if "--" in processor_auto_map: + upstream_repo = processor_auto_map.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) + + if has_remote_code and trust_remote_code: + processor_class = get_class_from_dynamic_module( + processor_auto_map, pretrained_model_name_or_path, **kwargs + ) + _ = kwargs.pop("code_revision", None) + processor_class.register_for_auto_class() + return processor_class.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + elif processor_class is not None: + return processor_class.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + # Last try: we use the PROCESSOR_MAPPING. + elif type(config) in PROCESSOR_MAPPING: + return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs) + + # At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a + # tokenizer. 
+ try: + return AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + except Exception: + try: + return AutoImageProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + except Exception: + pass + + try: + return AutoFeatureExtractor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + except Exception: + pass + + raise ValueError( + f"Unrecognized processing class in {pretrained_model_name_or_path}. Can't instantiate a processor, a " + "tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains " + "the files of at least one of those processing classes." + ) + + @staticmethod + def register(config_class, processor_class, exist_ok=False): + """ + Register a new processor for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + processor_class ([`ProcessorMixin`]): The processor to register. + """ + PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok) + + +__all__ = ["PROCESSOR_MAPPING", "AutoProcessor"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..163aba1cb12753ec7f9e487347b2a604df35fb28 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py @@ -0,0 +1,1235 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
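`AutoProcessor.register` above is the hook for wiring a custom processor into `PROCESSOR_MAPPING`. A hedged sketch of how it is typically paired with a custom config before the tokenizer mappings that follow — the `MyConfig`/`MyProcessor` names here are hypothetical and exist purely for illustration:

```python
# Hypothetical classes, shown only to illustrate the registration pattern.
from transformers import AutoConfig, AutoProcessor, PretrainedConfig, ProcessorMixin


class MyConfig(PretrainedConfig):
    model_type = "my-model"


class MyProcessor(ProcessorMixin):
    ...  # would wrap e.g. a tokenizer / feature extractor pair


AutoConfig.register("my-model", MyConfig)
AutoProcessor.register(MyConfig, MyProcessor)
# AutoProcessor.from_pretrained(...) can now resolve repos whose config reports model_type == "my-model".
```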
+"""Auto Tokenizer class.""" + +import importlib +import json +import os +import warnings +from collections import OrderedDict +from typing import Any, Optional, Union + +from transformers.utils.import_utils import is_mistral_common_available + +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE +from ...utils import ( + cached_file, + extract_commit_hash, + is_g2p_en_available, + is_sentencepiece_available, + is_tokenizers_available, + logging, +) +from ..encoder_decoder import EncoderDecoderConfig +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + config_class_to_model_type, + model_type_to_module_name, + replace_list_option_in_docstrings, +) + + +if is_tokenizers_available(): + from ...tokenization_utils_fast import PreTrainedTokenizerFast +else: + PreTrainedTokenizerFast = None + + +logger = logging.get_logger(__name__) + +# Explicit rather than inferred generics to significantly improves completion suggestion performance for language servers. +TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]]( + [ + ( + "aimv2", + ( + "CLIPTokenizer", + "CLIPTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "albert", + ( + "AlbertTokenizer" if is_sentencepiece_available() else None, + "AlbertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), + ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("bart", ("BartTokenizer", "BartTokenizerFast")), + ( + "barthez", + ( + "BarthezTokenizer" if is_sentencepiece_available() else None, + "BarthezTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("bartpho", ("BartphoTokenizer", None)), + ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)), + ("bert-japanese", ("BertJapaneseTokenizer", None)), + ("bertweet", ("BertweetTokenizer", None)), + ( + "big_bird", + ( + "BigBirdTokenizer" if is_sentencepiece_available() else None, + "BigBirdTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), + ("biogpt", ("BioGptTokenizer", None)), + ("bitnet", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")), + ("blenderbot-small", ("BlenderbotSmallTokenizer", None)), + ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)), + ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + 
("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("byt5", ("ByT5Tokenizer", None)), + ( + "camembert", + ( + "CamembertTokenizer" if is_sentencepiece_available() else None, + "CamembertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("canine", ("CanineTokenizer", None)), + ( + "chameleon", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ( + "clap", + ( + "RobertaTokenizer", + "RobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "clip", + ( + "CLIPTokenizer", + "CLIPTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "clipseg", + ( + "CLIPTokenizer", + "CLIPTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("clvp", ("ClvpTokenizer", None)), + ( + "code_llama", + ( + "CodeLlamaTokenizer" if is_sentencepiece_available() else None, + "CodeLlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)), + ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), + ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)), + ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), + ( + "cpm", + ( + "CpmTokenizer" if is_sentencepiece_available() else None, + "CpmTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("cpmant", ("CpmAntTokenizer", None)), + ("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("ctrl", ("CTRLTokenizer", None)), + ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)), + ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)), + ( + "deberta-v2", + ( + "DebertaV2Tokenizer" if is_sentencepiece_available() else None, + "DebertaV2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "deepseek_v2", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "deepseek_v3", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "deepseek_vl", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "deepseek_vl_hybrid", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("dia", ("DiaTokenizer", None)), + ( + "diffllama", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if 
is_tokenizers_available() else None)), + ( + "dpr", + ( + "DPRQuestionEncoderTokenizer", + "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), + ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)), + ("esm", ("EsmTokenizer", None)), + ( + "exaone4", + ( + "GPT2Tokenizer" if is_tokenizers_available() else None, + "GPT2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ( + "fastspeech2_conformer", + ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None), + ), + ("flaubert", ("FlaubertTokenizer", None)), + ("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), + ("fsmt", ("FSMTTokenizer", None)), + ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), + ( + "gemma", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma3", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma3_text", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma3n", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma3n_text", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)), + ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)), + ("gpt_oss", 
(None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)), + ("granite", ("GPT2Tokenizer", None)), + ("granitemoe", ("GPT2Tokenizer", None)), + ("granitemoehybrid", ("GPT2Tokenizer", None)), + ("granitemoeshared", ("GPT2Tokenizer", None)), + ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), + ("hubert", ("Wav2Vec2CTCTokenizer", None)), + ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ( + "jamba", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ( + "jetmoe", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("jukebox", ("JukeboxTokenizer", None)), + ( + "kosmos-2", + ( + "XLMRobertaTokenizer" if is_sentencepiece_available() else None, + "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), + ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), + ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)), + ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)), + ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), + ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)), + ( + "llama", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "llama4", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "llama4_text", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if 
is_tokenizers_available() else None)), + ("llava_onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), + ( + "longt5", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("luke", ("LukeTokenizer", None)), + ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), + ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), + ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), + ( + "mbart", + ( + "MBartTokenizer" if is_sentencepiece_available() else None, + "MBartTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "mbart50", + ( + "MBart50Tokenizer" if is_sentencepiece_available() else None, + "MBart50TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ( + "metaclip_2", + ( + "XLMRobertaTokenizer", + "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("mgp-str", ("MgpstrTokenizer", None)), + ( + "minimax", + ( + "GPT2Tokenizer" if is_sentencepiece_available() else None, + "GPT2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "ministral", + ( + "MistralCommonTokenizer" + if is_mistral_common_available() + else ("LlamaTokenizer" if is_sentencepiece_available() else None), + "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None, + ), + ), + ( + "mistral", + ( + "MistralCommonTokenizer" + if is_mistral_common_available() + else ("LlamaTokenizer" if is_sentencepiece_available() else None), + "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None, + ), + ), + ( + "mistral3", + ( + "MistralCommonTokenizer" + if is_mistral_common_available() + else ("LlamaTokenizer" if is_sentencepiece_available() else None), + "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None, + ), + ), + ( + "mixtral", + ( + "MistralCommonTokenizer" + if is_mistral_common_available() + else ("LlamaTokenizer" if is_sentencepiece_available() else None), + "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None, + ), + ), + ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), + ("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), + ("mpt", (None, "GPTNeoXTokenizerFast" 
if is_tokenizers_available() else None)), + ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ( + "mt5", + ( + "MT5Tokenizer" if is_sentencepiece_available() else None, + "MT5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)), + ("myt5", ("MyT5Tokenizer", None)), + ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ( + "nllb", + ( + "NllbTokenizer" if is_sentencepiece_available() else None, + "NllbTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "nllb-moe", + ( + "NllbTokenizer" if is_sentencepiece_available() else None, + "NllbTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "nystromformer", + ( + "AlbertTokenizer" if is_sentencepiece_available() else None, + "AlbertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("olmo2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("olmo3", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ( + "omdet-turbo", + ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None), + ), + ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ( + "openai-gpt", + ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None), + ), + ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("parakeet", ("ParakeetCTCTokenizer", None)), + ( + "pegasus", + ( + "PegasusTokenizer" if is_sentencepiece_available() else None, + "PegasusTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "pegasus_x", + ( + "PegasusTokenizer" if is_sentencepiece_available() else None, + "PegasusTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "perceiver", + ( + "PerceiverTokenizer", + None, + ), + ), + ( + "persimmon", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)), + ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("phobert", ("PhobertTokenizer", None)), + ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)), + ( + "pixtral", + ( + None, + "MistralCommonTokenizer" + if is_mistral_common_available() + else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None), + ), + ), + ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), + 
("prophetnet", ("ProphetNetTokenizer", None)), + ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ( + "qwen2", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ( + "qwen2_moe", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ( + "qwen3", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "qwen3_moe", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "qwen3_next", + ( + "Qwen2Tokenizer", + "Qwen2TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("qwen3_omni_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)), + ("rag", ("RagTokenizer", None)), + ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), + ( + "recurrent_gemma", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "reformer", + ( + "ReformerTokenizer" if is_sentencepiece_available() else None, + "ReformerTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "rembert", + ( + "RemBertTokenizer" if is_sentencepiece_available() else None, + "RemBertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), + ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ( + "roberta-prelayernorm", + ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None), + ), + ("roc_bert", ("RoCBertTokenizer", None)), + ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), + ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ( + "seamless_m4t", + ( + "SeamlessM4TTokenizer" if is_sentencepiece_available() else None, + "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "seamless_m4t_v2", + ( + "SeamlessM4TTokenizer" if is_sentencepiece_available() else None, + "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "shieldgemma2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)), + ( + "siglip2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), + ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), + ("speech_to_text_2", 
("Speech2Text2Tokenizer", None)), + ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)), + ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")), + ( + "squeezebert", + ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), + ), + ("stablelm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("starcoder2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ( + "switch_transformers", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "t5", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "t5gemma", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("tapas", ("TapasTokenizer", None)), + ("tapex", ("TapexTokenizer", None)), + ("transfo-xl", ("TransfoXLTokenizer", None)), + ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ( + "udop", + ( + "UdopTokenizer" if is_sentencepiece_available() else None, + "UdopTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "umt5", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), + ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("vits", ("VitsTokenizer", None)), + ( + "voxtral", + ( + "MistralCommonTokenizer" if is_mistral_common_available() else None, + "PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None, + ), + ), + ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), + ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)), + ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)), + ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)), + ("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)), + ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), + ( + "xglm", + ( + "XGLMTokenizer" if is_sentencepiece_available() else None, + "XGLMTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("xlm", ("XLMTokenizer", None)), + ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)), + ( + "xlm-roberta", + ( + "XLMRobertaTokenizer" if is_sentencepiece_available() else None, + "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "xlm-roberta-xl", + ( + "XLMRobertaTokenizer" if is_sentencepiece_available() else None, + "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "xlnet", + ( + "XLNetTokenizer" if is_sentencepiece_available() else None, + "XLNetTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ( + "xmod", + ( + "XLMRobertaTokenizer" if is_sentencepiece_available() else None, + "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "yoso", 
+ ( + "AlbertTokenizer" if is_sentencepiece_available() else None, + "AlbertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "zamba", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "zamba2", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ] +) + +TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES) + +CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()} + + +def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]: + if class_name == "PreTrainedTokenizerFast": + return PreTrainedTokenizerFast + + for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): + if class_name in tokenizers: + module_name = model_type_to_module_name(module_name) + if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonTokenizer": + module = importlib.import_module(".tokenization_mistral_common", "transformers") + else: + module = importlib.import_module(f".{module_name}", "transformers.models") + try: + return getattr(module, class_name) + except AttributeError: + continue + + for tokenizers in TOKENIZER_MAPPING._extra_content.values(): + for tokenizer in tokenizers: + if getattr(tokenizer, "__name__", None) == class_name: + return tokenizer + + # We did not fine the class, but maybe it's because a dep is missing. In that case, the class will be in the main + # init and we return the proper dummy to get an appropriate error message. + main_module = importlib.import_module("transformers") + if hasattr(main_module, class_name): + return getattr(main_module, class_name) + + return None + + +def get_tokenizer_config( + pretrained_model_name_or_path: Union[str, os.PathLike[str]], + cache_dir: Optional[Union[str, os.PathLike[str]]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + subfolder: str = "", + **kwargs, +) -> dict[str, Any]: + """ + Loads the tokenizer configuration from a pretrained model tokenizer configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + subfolder (`str`, *optional*, defaults to `""`): + In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `dict`: The configuration of the tokenizer. + + Examples: + + ```python + # Download configuration from huggingface.co and cache. + tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased") + # This model does not have a tokenizer config so the result will be an empty dict. + tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base") + + # Save a pretrained tokenizer locally and you can reload its config + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased") + tokenizer.save_pretrained("tokenizer-test") + tokenizer_config = get_tokenizer_config("tokenizer-test") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + commit_hash = kwargs.get("_commit_hash") + resolved_config_file = cached_file( + pretrained_model_name_or_path, + TOKENIZER_CONFIG_FILE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + _commit_hash=commit_hash, + ) + if resolved_config_file is None: + logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") + return {} + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + + with open(resolved_config_file, encoding="utf-8") as reader: + result = json.load(reader) + result["_commit_hash"] = commit_hash + return result + + +class AutoTokenizer: + r""" + This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when + created with the [`AutoTokenizer.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self): + raise OSError( + "AutoTokenizer is designed to be instantiated " + "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES) + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + r""" + Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary. 
+ + The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either + passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by + falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Params: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a + single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not + applicable to all derived classes) + inputs (additional positional arguments, *optional*): + Will be passed along to the Tokenizer `__init__()` method. + config ([`PretrainedConfig`], *optional*) + The configuration object used to determine the tokenizer class to instantiate. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download the model weights and configuration files and override the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for + facebook/rag-token-base), specify it here. + use_fast (`bool`, *optional*, defaults to `True`): + Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported for + a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer + is returned instead. + tokenizer_type (`str`, *optional*): + Tokenizer type to be loaded. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (additional keyword arguments, *optional*): + Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like + `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`, + `additional_special_tokens`. See parameters in the `__init__()` for more details. + + Examples: + + ```python + >>> from transformers import AutoTokenizer + + >>> # Download vocabulary from huggingface.co and cache. 
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + + >>> # Download vocabulary from huggingface.co (user-uploaded) and cache. + >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased") + + >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*) + >>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/") + + >>> # Download vocabulary from huggingface.co and define model-specific arguments + >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True) + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + config = kwargs.pop("config", None) + kwargs["_from_auto"] = True + + use_fast = kwargs.pop("use_fast", True) + tokenizer_type = kwargs.pop("tokenizer_type", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + gguf_file = kwargs.get("gguf_file") + + # First, let's see whether the tokenizer_type is passed so that we can leverage it + if tokenizer_type is not None: + tokenizer_class = None + tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None) + + if tokenizer_class_tuple is None: + raise ValueError( + f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of " + f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES)}." + ) + + tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple + + if use_fast: + if tokenizer_fast_class_name is not None: + tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name) + else: + logger.warning( + "`use_fast` is set to `True` but the tokenizer class does not have a fast version. " + " Falling back to the slow version." + ) + if tokenizer_class is None: + tokenizer_class = tokenizer_class_from_name(tokenizer_class_name) + + if tokenizer_class is None: + raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.") + + return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + + # Next, let's try to use the tokenizer_config file to get the tokenizer class. + tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs) + if "_commit_hash" in tokenizer_config: + kwargs["_commit_hash"] = tokenizer_config["_commit_hash"] + config_tokenizer_class = tokenizer_config.get("tokenizer_class") + tokenizer_auto_map = None + if "auto_map" in tokenizer_config: + if isinstance(tokenizer_config["auto_map"], (tuple, list)): + # Legacy format for dynamic tokenizers + tokenizer_auto_map = tokenizer_config["auto_map"] + else: + tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None) + + # If that did not work, let's try to use the config. 
+ if config_tokenizer_class is None: + if not isinstance(config, PretrainedConfig): + if gguf_file: + gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs) + config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"] + config = AutoConfig.for_model(**config_dict) + else: + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + config_tokenizer_class = config.tokenizer_class + if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map: + tokenizer_auto_map = config.auto_map["AutoTokenizer"] + + has_remote_code = tokenizer_auto_map is not None + has_local_code = type(config) in TOKENIZER_MAPPING or ( + config_tokenizer_class is not None + and ( + tokenizer_class_from_name(config_tokenizer_class) is not None + or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None + ) + ) + if has_remote_code: + if use_fast and tokenizer_auto_map[1] is not None: + class_ref = tokenizer_auto_map[1] + else: + class_ref = tokenizer_auto_map[0] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) + + if has_remote_code and trust_remote_code: + tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) + _ = kwargs.pop("code_revision", None) + tokenizer_class.register_for_auto_class() + return tokenizer_class.from_pretrained( + pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs + ) + elif config_tokenizer_class is not None: + tokenizer_class = None + if use_fast and not config_tokenizer_class.endswith("Fast"): + tokenizer_class_candidate = f"{config_tokenizer_class}Fast" + tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) + if tokenizer_class is None: + tokenizer_class_candidate = config_tokenizer_class + tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate) + if tokenizer_class is None: + raise ValueError( + f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported." + ) + return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + + # Otherwise we have to be creative. + # if model is an encoder decoder, the encoder tokenizer class is used by default + if isinstance(config, EncoderDecoderConfig): + if type(config.decoder) is not type(config.encoder): + logger.warning( + f"The encoder model config class: {config.encoder.__class__} is different from the decoder model " + f"config class: {config.decoder.__class__}. It is not recommended to use the " + "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder " + "specific tokenizer classes." + ) + config = config.encoder + + model_type = config_class_to_model_type(type(config).__name__) + if model_type is not None: + tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] + + if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): + return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + else: + if tokenizer_class_py is not None: + return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + else: + raise ValueError( + "This tokenizer cannot be instantiated. 
Please make sure you have `sentencepiece` installed " + "in order to use this tokenizer." + ) + + raise ValueError( + f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n" + f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING)}." + ) + + @staticmethod + def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False): + """ + Register a new tokenizer in this mapping. + + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + slow_tokenizer_class ([`PretrainedTokenizer`], *optional*): + The slow tokenizer to register. + fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*): + The fast tokenizer to register. + """ + if slow_tokenizer_class is None and fast_tokenizer_class is None: + raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class") + if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast): + raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.") + if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer): + raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.") + + if ( + slow_tokenizer_class is not None + and fast_tokenizer_class is not None + and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast) + and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class + ): + raise ValueError( + "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not " + "consistent with the slow tokenizer class you passed (fast tokenizer has " + f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those " + "so they match!" + ) + + # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones. + if config_class in TOKENIZER_MAPPING._extra_content: + existing_slow, existing_fast = TOKENIZER_MAPPING[config_class] + if slow_tokenizer_class is None: + slow_tokenizer_class = existing_slow + if fast_tokenizer_class is None: + fast_tokenizer_class = existing_fast + + TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok) + + +__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py new file mode 100644 index 0000000000000000000000000000000000000000..84bbc8e6fdb10ea5e0a72caec2135825ff95dc20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py @@ -0,0 +1,393 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
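[Editor's note — illustrative sketch, not part of the vendored file above.] The `use_fast` flag documented in `AutoTokenizer.from_pretrained` selects between the two slots of `TOKENIZER_MAPPING_NAMES` (slow Python tokenizer vs. fast Rust-backed tokenizer). A minimal caller-side sketch, assuming the `tokenizers` package is installed:

```python
# Illustrative only -- not part of the vendored diff. Sketch of the use_fast resolution
# implemented in tokenization_auto.py.
from transformers import AutoTokenizer

# Default: the fast tokenizer is preferred when one is registered for the model type.
fast_tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

# use_fast=False forces the slow, pure-Python tokenizer from the first slot of the mapping.
slow_tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", use_fast=False)

print(type(fast_tok).__name__)  # BertTokenizerFast (when `tokenizers` is available)
print(type(slow_tok).__name__)  # BertTokenizer
```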
+"""AutoVideoProcessor class.""" + +import importlib +import json +import os +import warnings +from collections import OrderedDict +from typing import TYPE_CHECKING, Optional, Union + +# Build the list of all video processors +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging +from ...utils.import_utils import requires +from ...video_processing_utils import BaseVideoProcessor +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + model_type_to_module_name, + replace_list_option_in_docstrings, +) + + +logger = logging.get_logger(__name__) + + +if TYPE_CHECKING: + # This significantly improves completion suggestion performance when + # the transformers package is used with Microsoft's Pylance language server. + VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict() +else: + VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ("glm4v", "Glm4vVideoProcessor"), + ("instructblip", "InstructBlipVideoVideoProcessor"), + ("instructblipvideo", "InstructBlipVideoVideoProcessor"), + ("internvl", "InternVLVideoProcessor"), + ("llava_next_video", "LlavaNextVideoVideoProcessor"), + ("llava_onevision", "LlavaOnevisionVideoProcessor"), + ("perception_lm", "PerceptionLMVideoProcessor"), + ("qwen2_5_omni", "Qwen2VLVideoProcessor"), + ("qwen2_5_vl", "Qwen2VLVideoProcessor"), + ("qwen2_vl", "Qwen2VLVideoProcessor"), + ("qwen3_omni_moe", "Qwen2VLVideoProcessor"), + ("qwen3_vl", "Qwen3VLVideoProcessor"), + ("qwen3_vl_moe", "Qwen3VLVideoProcessor"), + ("sam2_video", "Sam2VideoVideoProcessor"), + ("smolvlm", "SmolVLMVideoProcessor"), + ("video_llava", "VideoLlavaVideoProcessor"), + ("vjepa2", "VJEPA2VideoProcessor"), + ] + ) + +for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items(): + fast_video_processor_class = video_processors + + # If the torchvision is not available, we set it to None + if not is_torchvision_available(): + fast_video_processor_class = None + + VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class + +VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES) + + +def video_processor_class_from_name(class_name: str): + for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items(): + if class_name in extractors: + module_name = model_type_to_module_name(module_name) + + module = importlib.import_module(f".{module_name}", "transformers.models") + try: + return getattr(module, class_name) + except AttributeError: + continue + + for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values(): + if getattr(extractor, "__name__", None) == class_name: + return extractor + + # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main + # init and we return the proper dummy to get an appropriate error message. 
+ main_module = importlib.import_module("transformers") + if hasattr(main_module, class_name): + return getattr(main_module, class_name) + + return None + + +def get_video_processor_config( + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Loads the video processor configuration from a pretrained model video processor configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the video processor configuration from local files. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `Dict`: The configuration of the video processor. + + Examples: + + ```python + # Download configuration from huggingface.co and cache. + video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + # This model does not have a video processor config so the result will be an empty dict. + video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base") + + # Save a pretrained video processor locally and you can reload its config + from transformers import AutoVideoProcessor + + video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + video_processor.save_pretrained("video-processor-test") + video_processor = get_video_processor_config("video-processor-test") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + resolved_config_file = cached_file( + pretrained_model_name_or_path, + VIDEO_PROCESSOR_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + ) + if resolved_config_file is None: + logger.info( + "Could not locate the video processor configuration file, will try to use the model config instead." + ) + return {} + + with open(resolved_config_file, encoding="utf-8") as reader: + return json.load(reader) + + +@requires(backends=("vision", "torchvision")) +class AutoVideoProcessor: + r""" + This is a generic video processor class that will be instantiated as one of the video processor classes of the + library when created with the [`AutoVideoProcessor.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self): + raise OSError( + "AutoVideoProcessor is designed to be instantiated " + "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES) + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + r""" + Instantiate one of the video processor classes of the library from a pretrained model vocabulary. + + The video processor class to instantiate is selected based on the `model_type` property of the config object + (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's + missing, by falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Params: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained video_processor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a video processor file saved using the + [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved video processor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model video processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the video processor files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `hf auth login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final video processor object. If `True`, then this + functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of + `kwargs` which has not been used to update `video_processor` and is otherwise ignored. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are video processor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + + + Passing `token=True` is required when you want to use a private model. + + + + Examples: + + ```python + >>> from transformers import AutoVideoProcessor + + >>> # Download video processor from huggingface.co and cache. + >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + + >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*) + >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token") is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + config = kwargs.pop("config", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs["_from_auto"] = True + + config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs) + video_processor_class = config_dict.get("video_processor_type", None) + video_processor_auto_map = None + if "AutoVideoProcessor" in config_dict.get("auto_map", {}): + video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"] + + # If we still don't have the video processor class, check if we're loading from a previous image processor config + # and if so, infer the video processor class from there. + if video_processor_class is None and video_processor_auto_map is None: + image_processor_class = config_dict.pop("image_processor_type", None) + if image_processor_class is not None: + video_processor_class_inferred = image_processor_class.replace("ImageProcessor", "VideoProcessor") + + # Some models have different image processors, e.g. 
InternVL uses GotOCRImageProcessor + # We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on + if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values(): + video_processor_class = video_processor_class_inferred + if "AutoImageProcessor" in config_dict.get("auto_map", {}): + image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"] + video_processor_auto_map = image_processor_auto_map.replace("ImageProcessor", "VideoProcessor") + + # If we don't find the video processor class in the video processor config, let's try the model config. + if video_processor_class is None and video_processor_auto_map is None: + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + # It could be in `config.video_processor_type`` + video_processor_class = getattr(config, "video_processor_type", None) + if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map: + video_processor_auto_map = config.auto_map["AutoVideoProcessor"] + + if video_processor_class is not None: + video_processor_class = video_processor_class_from_name(video_processor_class) + + has_remote_code = video_processor_auto_map is not None + has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code + ) + + if has_remote_code and trust_remote_code: + class_ref = video_processor_auto_map + video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) + _ = kwargs.pop("code_revision", None) + video_processor_class.register_for_auto_class() + return video_processor_class.from_dict(config_dict, **kwargs) + elif video_processor_class is not None: + return video_processor_class.from_dict(config_dict, **kwargs) + # Last try: we use the VIDEO_PROCESSOR_MAPPING. + elif type(config) in VIDEO_PROCESSOR_MAPPING: + video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)] + + if video_processor_class is not None: + return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + else: + raise ValueError( + "This video processor cannot be instantiated. Please make sure you have `torchvision` installed." + ) + + raise ValueError( + f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a " + f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following " + f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES)}" + ) + + @staticmethod + def register( + config_class, + video_processor_class, + exist_ok=False, + ): + """ + Register a new video processor for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + video_processor_class ([`BaseVideoProcessor`]): + The video processor to register. 
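The `register` hook documented above pairs a configuration class with a video processor class so that `AutoVideoProcessor` can later resolve it. A minimal sketch of how that registration might look, assuming torchvision is installed; `MyVideoConfig` and `MyVideoProcessor` are hypothetical user-defined classes, not part of the library:

```python
# Illustrative sketch only: `MyVideoConfig` and `MyVideoProcessor` are hypothetical.
from transformers import AutoConfig, AutoVideoProcessor, PretrainedConfig
from transformers.video_processing_utils import BaseVideoProcessor


class MyVideoConfig(PretrainedConfig):
    model_type = "my-video-model"


class MyVideoProcessor(BaseVideoProcessor):
    pass


# Register the pair so the auto classes can resolve "my-video-model" checkpoints.
AutoConfig.register("my-video-model", MyVideoConfig)
AutoVideoProcessor.register(MyVideoConfig, MyVideoProcessor)
```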
+ """ + VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok) + + +__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5296fc47fc423c300cf6c43bccb5daafd6d134a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bark import * + from .modeling_bark import * + from .processing_bark import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py new file mode 100644 index 0000000000000000000000000000000000000000..25787a90d61cf15a41bea24e9555f85c34a92ac8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BARK model configuration""" + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import add_start_docstrings, logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +BARK_SUBMODELCONFIG_START_DOCSTRING = """ + This is the configuration class to store the configuration of a [`{model}`]. It is used to instantiate the model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Bark [suno/bark](https://huggingface.co/suno/bark) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + block_size (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + input_vocab_size (`int`, *optional*, defaults to 10_048): + Vocabulary size of a Bark sub-model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`{model}`]. Defaults to 10_048 but should be carefully thought with + regards to the chosen sub-model. + output_vocab_size (`int`, *optional*, defaults to 10_048): + Output vocabulary size of a Bark sub-model. Defines the number of different tokens that can be represented + by the: `output_ids` when passing forward a [`{model}`]. Defaults to 10_048 but should be carefully thought + with regards to the chosen sub-model. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the given sub-model. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer architecture. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the "intermediate" (often named feed-forward) layer in the architecture. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the linear layers and layer norm layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). +""" + + +class BarkSubModelConfig(PretrainedConfig): + keys_to_ignore_at_inference = ["past_key_values"] + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "vocab_size": "input_vocab_size", + "window_size": "block_size", + } + + def __init__( + self, + block_size=1024, + input_vocab_size=10_048, + output_vocab_size=10_048, + num_layers=12, + num_heads=12, + hidden_size=768, + dropout=0.0, + bias=True, # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + initializer_range=0.02, + use_cache=True, + **kwargs, + ): + self.block_size = block_size + self.input_vocab_size = input_vocab_size + self.output_vocab_size = output_vocab_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_size = hidden_size + self.dropout = dropout + self.bias = bias + self.use_cache = use_cache + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +@add_start_docstrings( + BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkSemanticConfig", model="BarkSemanticModel"), + """ + Example: + + ```python + >>> from transformers import BarkSemanticConfig, BarkSemanticModel + + >>> # Initializing a Bark sub-module style configuration + >>> configuration = BarkSemanticConfig() + + >>> # Initializing a model (with random weights) from the suno/bark style configuration + >>> model = BarkSemanticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""", +) +class BarkSemanticConfig(BarkSubModelConfig): + model_type = "semantic" + base_config_key = "semantic_config" + + +@add_start_docstrings( + BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkCoarseConfig", model="BarkCoarseModel"), + """ + Example: + + ```python + >>> from transformers import BarkCoarseConfig, BarkCoarseModel + + >>> # Initializing a Bark sub-module style configuration + >>> configuration = BarkCoarseConfig() + + >>> # Initializing a model (with random weights) from the suno/bark style configuration + >>> model = BarkCoarseModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""", +) +class BarkCoarseConfig(BarkSubModelConfig): + model_type = "coarse_acoustics" + base_config_key = "coarse_acoustics_config" + + +@add_start_docstrings( + BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkFineConfig", model="BarkFineModel"), + """ + n_codes_total (`int`, *optional*, defaults to 8): + The total number of audio codebooks predicted. Used in the fine acoustics sub-model. + n_codes_given (`int`, *optional*, defaults to 1): + The number of audio codebooks predicted in the coarse acoustics sub-model. Used in the acoustics + sub-models. + Example: + + ```python + >>> from transformers import BarkFineConfig, BarkFineModel + + >>> # Initializing a Bark sub-module style configuration + >>> configuration = BarkFineConfig() + + >>> # Initializing a model (with random weights) from the suno/bark style configuration + >>> model = BarkFineModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""", +) +class BarkFineConfig(BarkSubModelConfig): + model_type = "fine_acoustics" + base_config_key = "fine_acoustics_config" + + def __init__(self, tie_word_embeddings=True, n_codes_total=8, n_codes_given=1, **kwargs): + self.n_codes_total = n_codes_total + self.n_codes_given = n_codes_given + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class BarkConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`BarkModel`]. It is used to instantiate a Bark + model according to the specified sub-models configurations, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the Bark + [suno/bark](https://huggingface.co/suno/bark) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + semantic_config ([`BarkSemanticConfig`], *optional*): + Configuration of the underlying semantic sub-model. + coarse_acoustics_config ([`BarkCoarseConfig`], *optional*): + Configuration of the underlying coarse acoustics sub-model. + fine_acoustics_config ([`BarkFineConfig`], *optional*): + Configuration of the underlying fine acoustics sub-model. + codec_config ([`AutoConfig`], *optional*): + Configuration of the underlying codec sub-model. + + Example: + + ```python + >>> from transformers import ( + ... BarkSemanticConfig, + ... BarkCoarseConfig, + ... BarkFineConfig, + ... BarkModel, + ... BarkConfig, + ... AutoConfig, + ... ) + + >>> # Initializing Bark sub-modules configurations. + >>> semantic_config = BarkSemanticConfig() + >>> coarse_acoustics_config = BarkCoarseConfig() + >>> fine_acoustics_config = BarkFineConfig() + >>> codec_config = AutoConfig.from_pretrained("facebook/encodec_24khz") + + + >>> # Initializing a Bark module style configuration + >>> configuration = BarkConfig.from_sub_model_configs( + ... semantic_config, coarse_acoustics_config, fine_acoustics_config, codec_config + ... ) + + >>> # Initializing a model (with random weights) + >>> model = BarkModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "bark" + sub_configs = { + "semantic_config": BarkSemanticConfig, + "coarse_acoustics_config": BarkCoarseConfig, + "fine_acoustics_config": BarkFineConfig, + "codec_config": AutoConfig, + } + + def __init__( + self, + semantic_config: Optional[dict] = None, + coarse_acoustics_config: Optional[dict] = None, + fine_acoustics_config: Optional[dict] = None, + codec_config: Optional[dict] = None, + initializer_range=0.02, + **kwargs, + ): + if semantic_config is None: + semantic_config = {} + logger.info("semantic_config is None. initializing the semantic model with default values.") + + if coarse_acoustics_config is None: + coarse_acoustics_config = {} + logger.info("coarse_acoustics_config is None. initializing the coarse model with default values.") + + if fine_acoustics_config is None: + fine_acoustics_config = {} + logger.info("fine_acoustics_config is None. initializing the fine model with default values.") + + if codec_config is None: + codec_config = {} + logger.info("codec_config is None. initializing the codec model with default values.") + + self.semantic_config = BarkSemanticConfig(**semantic_config) + self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config) + self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config) + codec_model_type = codec_config.get("model_type", "encodec") + self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config) + + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + @classmethod + def from_sub_model_configs( + cls, + semantic_config: BarkSemanticConfig, + coarse_acoustics_config: BarkCoarseConfig, + fine_acoustics_config: BarkFineConfig, + codec_config: PretrainedConfig, + **kwargs, + ): + r""" + Instantiate a [`BarkConfig`] (or a derived class) from bark sub-models configuration. 
+ + Returns: + [`BarkConfig`]: An instance of a configuration object + """ + return cls( + semantic_config=semantic_config.to_dict(), + coarse_acoustics_config=coarse_acoustics_config.to_dict(), + fine_acoustics_config=fine_acoustics_config.to_dict(), + codec_config=codec_config.to_dict(), + **kwargs, + ) + + +__all__ = ["BarkCoarseConfig", "BarkConfig", "BarkFineConfig", "BarkSemanticConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa68184c88526ccc793a336a64bce798f6d7759 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BARK model generation configuration""" + +import copy +from typing import Optional + +from ...generation.configuration_utils import GenerationConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BarkSemanticGenerationConfig(GenerationConfig): + model_type = "semantic" + + def __init__( + self, + eos_token_id=10_000, + renormalize_logits=True, + max_new_tokens=768, + output_scores=False, + return_dict_in_generate=False, + output_hidden_states=False, + output_attentions=False, + temperature=1.0, + do_sample=False, + text_encoding_offset=10_048, + text_pad_token=129_595, + semantic_infer_token=129_599, + semantic_vocab_size=10_000, + max_input_semantic_length=256, + semantic_rate_hz=49.9, + min_eos_p=None, + **kwargs, + ): + """Class that holds a generation configuration for [`BarkSemanticModel`]. + + This configuration inherit from [`GenerationConfig`] and can be used to control the model generation. Read the + documentation from [`GenerationConfig`] for more information. + + Args: + eos_token_id (`int`, *optional*, defaults to 10_000): + The id of the *end-of-sequence* token. + renormalize_logits (`bool`, *optional*, defaults to `True`): + Whether to renormalize the logits after applying all the logits processors (including the + custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the + score logits are normalized but some logit processors break the normalization. + max_new_tokens (`int`, *optional*, defaults to 768): + The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + temperature (`float`, *optional*, defaults to 1.0): + The value used to modulate the next token probabilities. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + text_encoding_offset (`int`, *optional*, defaults to 10_048): + Text encoding offset. + text_pad_token (`int`, *optional*, defaults to 129_595): + Text pad token. + semantic_infer_token (`int`, *optional*, defaults to 129_599): + Semantic infer token. + semantic_vocab_size (`int`, *optional*, defaults to 10_000): + Semantic vocab size. + max_input_semantic_length (`int`, *optional*, defaults to 256): + Max length of semantic input vector. + semantic_rate_hz (`float`, *optional*, defaults to 49.9): + Semantic rate in Hertz. + min_eos_p (`float`, *optional*): + Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping + strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation + suggests a default value of 0.2. + """ + super().__init__( + temperature=temperature, + do_sample=do_sample, + eos_token_id=eos_token_id, + renormalize_logits=renormalize_logits, + max_new_tokens=max_new_tokens, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + **kwargs, + ) + + self.text_encoding_offset = text_encoding_offset + self.text_pad_token = text_pad_token + self.semantic_pad_token = eos_token_id + self.semantic_infer_token = semantic_infer_token + self.semantic_vocab_size = semantic_vocab_size + self.max_input_semantic_length = max_input_semantic_length + self.semantic_rate_hz = semantic_rate_hz + self.min_eos_p = min_eos_p + + +class BarkCoarseGenerationConfig(GenerationConfig): + model_type = "coarse_acoustics" + + def __init__( + self, + renormalize_logits=True, + output_scores=False, + return_dict_in_generate=False, + output_hidden_states=False, + output_attentions=False, + temperature=1.0, + do_sample=False, + coarse_semantic_pad_token=12_048, + coarse_rate_hz=75, + n_coarse_codebooks=2, + coarse_infer_token=12_050, + max_coarse_input_length=256, + max_coarse_history: int = 630, + sliding_window_len: int = 60, + **kwargs, + ): + """Class that holds a generation configuration for [`BarkCoarseModel`]. + + This configuration inherit from [`GenerationConfig`] and can be used to control the model generation. Read the + documentation from [`GenerationConfig`] for more information. + + Args: + renormalize_logits (`bool`, *optional*, defaults to `True`): + Whether to renormalize the logits after applying all the logits processors (including the + custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the + score logits are normalized but some logit processors break the normalization. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + temperature (`float`, *optional*, defaults to 1.0): + The value used to modulate the next token probabilities. + do_sample (`bool`, *optional*, defaults to `False`): + Whether or not to use sampling ; use greedy decoding otherwise. + coarse_semantic_pad_token (`int`, *optional*, defaults to 12_048): + Coarse semantic pad token. + coarse_rate_hz (`int`, *optional*, defaults to 75): + Coarse rate in Hertz. + n_coarse_codebooks (`int`, *optional*, defaults to 2): + Number of coarse codebooks. + coarse_infer_token (`int`, *optional*, defaults to 12_050): + Coarse infer token. + max_coarse_input_length (`int`, *optional*, defaults to 256): + Max length of input coarse vector. + max_coarse_history (`int`, *optional*, defaults to 630): + Max length of the output of the coarse acoustics model used in the fine generation step. + sliding_window_len (`int`, *optional*, defaults to 60): + The coarse generation step uses a sliding window to generate raw audio. + """ + super().__init__( + temperature=temperature, + do_sample=do_sample, + renormalize_logits=renormalize_logits, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + **kwargs, + ) + + self.coarse_semantic_pad_token = coarse_semantic_pad_token + self.coarse_rate_hz = coarse_rate_hz + self.n_coarse_codebooks = n_coarse_codebooks + self.coarse_infer_token = coarse_infer_token + self.max_coarse_input_length = max_coarse_input_length + self.max_coarse_history = max_coarse_history + self.sliding_window_len = sliding_window_len + + +class BarkFineGenerationConfig(GenerationConfig): + model_type = "fine_acoustics" + + def __init__( + self, + temperature=1.0, + max_fine_history_length=512, + max_fine_input_length=1024, + n_fine_codebooks=8, + **kwargs, + ): + """Class that holds a generation configuration for [`BarkFineModel`]. + + [`BarkFineModel`] is an autoencoder model, so should not usually be used for generation. However, under the + hood, it uses `temperature` when used by [`BarkModel`] + + This configuration inherit from [`GenerationConfig`] and can be used to control the model generation. Read the + documentation from [`GenerationConfig`] for more information. + + Args: + temperature (`float`, *optional*): + The value used to modulate the next token probabilities. + max_fine_history_length (`int`, *optional*, defaults to 512): + Max length of the fine history vector. + max_fine_input_length (`int`, *optional*, defaults to 1024): + Max length of fine input vector. + n_fine_codebooks (`int`, *optional*, defaults to 8): + Number of codebooks used. + """ + super().__init__(temperature=temperature) + + self.max_fine_history_length = max_fine_history_length + self.max_fine_input_length = max_fine_input_length + self.n_fine_codebooks = n_fine_codebooks + + def validate(self, **kwargs): + """ + Overrides GenerationConfig.validate because BarkFineGenerationConfig don't use any parameters outside + temperature. 
+ """ + pass + + +class BarkGenerationConfig(GenerationConfig): + model_type = "bark" + + # TODO (joao): nested from_dict + + def __init__( + self, + semantic_config: Optional[dict] = None, + coarse_acoustics_config: Optional[dict] = None, + fine_acoustics_config: Optional[dict] = None, + sample_rate=24_000, + codebook_size=1024, + **kwargs, + ): + """Class that holds a generation configuration for [`BarkModel`]. + + The [`BarkModel`] does not have a `generate` method, but uses this class to generate speeches with a nested + [`BarkGenerationConfig`] which uses [`BarkSemanticGenerationConfig`], [`BarkCoarseGenerationConfig`], + [`BarkFineGenerationConfig`]. + + This configuration inherit from [`GenerationConfig`] and can be used to control the model generation. Read the + documentation from [`GenerationConfig`] for more information. + + Args: + semantic_config (`Dict`, *optional*): + Semantic generation configuration. + coarse_acoustics_config (`Dict`, *optional*): + Coarse generation configuration. + fine_acoustics_config (`Dict`, *optional*): + Fine generation configuration. + sample_rate (`int`, *optional*, defaults to 24_000): + Sample rate. + codebook_size (`int`, *optional*, defaults to 1024): + Vector length for each codebook. + """ + if semantic_config is None: + semantic_config = {} + logger.info("semantic_config is None. initializing the semantic model with default values.") + + if coarse_acoustics_config is None: + coarse_acoustics_config = {} + logger.info("coarse_acoustics_config is None. initializing the coarse model with default values.") + + if fine_acoustics_config is None: + fine_acoustics_config = {} + logger.info("fine_acoustics_config is None. initializing the fine model with default values.") + + self.semantic_config = BarkSemanticGenerationConfig(**semantic_config) + self.coarse_acoustics_config = BarkCoarseGenerationConfig(**coarse_acoustics_config) + self.fine_acoustics_config = BarkFineGenerationConfig(**fine_acoustics_config) + + self.sample_rate = sample_rate + self.codebook_size = codebook_size + + @classmethod + def from_sub_model_configs( + cls, + semantic_config: BarkSemanticGenerationConfig, + coarse_acoustics_config: BarkCoarseGenerationConfig, + fine_acoustics_config: BarkFineGenerationConfig, + **kwargs, + ): + r""" + Instantiate a [`BarkGenerationConfig`] (or a derived class) from bark sub-models generation configuration. + + Returns: + [`BarkGenerationConfig`]: An instance of a configuration object + """ + return cls( + semantic_config=semantic_config.to_dict(), + coarse_acoustics_config=coarse_acoustics_config.to_dict(), + fine_acoustics_config=fine_acoustics_config.to_dict(), + **kwargs, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
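The nested `BarkGenerationConfig` described above is normally assembled from the three sub-model generation configs via `from_sub_model_configs` and serialized with the overridden `to_dict`. A small sketch, under the assumption that the classes are imported from their defining module (they are not necessarily re-exported at the top-level `transformers` namespace):

```python
from transformers.models.bark.generation_configuration_bark import (
    BarkCoarseGenerationConfig,
    BarkFineGenerationConfig,
    BarkGenerationConfig,
    BarkSemanticGenerationConfig,
)

# Assemble the nested generation config from the per-sub-model configs.
generation_config = BarkGenerationConfig.from_sub_model_configs(
    BarkSemanticGenerationConfig(min_eos_p=0.2),  # early-stopping threshold suggested upstream
    BarkCoarseGenerationConfig(),
    BarkFineGenerationConfig(),
)

# The overridden `to_dict` re-serializes the three nested configs and adds `model_type`.
as_dict = generation_config.to_dict()
print(sorted(key for key in as_dict if key.endswith("_config")))
```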
+ + Returns: + `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + output["semantic_config"] = self.semantic_config.to_dict() + output["coarse_acoustics_config"] = self.coarse_acoustics_config.to_dict() + output["fine_acoustics_config"] = self.fine_acoustics_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py new file mode 100644 index 0000000000000000000000000000000000000000..af57f7826734f7ae4bd5cdcac5c9ba3d29bbf969 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py @@ -0,0 +1,1628 @@ +# coding=utf-8 +# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BARK model.""" + +import math +import warnings +from typing import Optional, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...generation.logits_process import ( + AlternatingCodebooksLogitsProcessor, + BarkEosPrioritizerLogitsProcessor, + SuppressTokensLogitsProcessor, +) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput +from ...modeling_utils import PreTrainedModel, get_parameter_device +from ...utils import ( + auto_docstring, + is_accelerate_available, + is_torch_accelerator_available, + logging, +) +from ..auto import AutoModel +from .configuration_bark import ( + BarkCoarseConfig, + BarkConfig, + BarkFineConfig, + BarkSemanticConfig, + BarkSubModelConfig, +) +from .generation_configuration_bark import ( + BarkCoarseGenerationConfig, + BarkFineGenerationConfig, + BarkSemanticGenerationConfig, +) + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +class BarkSelfAttention(nn.Module): + # adapted from GPTNeoSelfAttention and Bark code + # BarkSelfAttention can have two attention type, i.e full attention or causal attention + + def __init__(self, config, is_causal=False, layer_idx=None): + super().__init__() + + # regularization + self.dropout = config.dropout + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.embed_dim // self.num_heads + + if config.hidden_size % config.num_heads != 0: + raise ValueError( + f"embed_dim must be 
divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + # key, query, value projections for all heads, but in a batch + self.att_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias) + # output projection + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias) + + self.is_causal = is_causal + self.layer_idx = layer_idx + if is_causal: + block_size = config.block_size + bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size) + self.register_buffer("bias", bias) + + # Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + + # re-assemble all head outputs side by side + # (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size) + tensor = tensor.transpose(1, 2).contiguous() + tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,)) + + return tensor + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # unlike GPTNeo's SelfAttention, divide by the square root of the dimension of the query and the key + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * (1.0 / math.sqrt(self.head_dim)) + + if self.is_causal: + query_length, key_length = query.size(-2), key.size(-2) + + # fill the upper left part of the attention weights with inf + attn_weights = attn_weights.masked_fill( + self.bias[:, :, key_length - query_length : key_length, :key_length] == 0, + torch.finfo(attn_weights.dtype).min, + ) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + # (batch, num_heads, seq_len, seq_len) x (batch, num_heads, seq_len, attn_head_size) + # -> (batch, num_heads, seq_len, attn_head_size) + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_values=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if past_key_values is not None: + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) + + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = 
self.resid_dropout(attn_output) + + return attn_output, attn_weights + + +class BarkSelfFlashAttention2(BarkSelfAttention): + """ + Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + # re-assemble all head outputs side by side + # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size) + tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,)) + return tensor + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_values=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + batch_size, query_len, _ = hidden_states.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if past_key_values is not None: + key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position}) + + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_len, + dropout=self.dropout if self.training else 0.0, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + return attn_output, None + + +BARK_ATTENTION_CLASSES = { + "eager": BarkSelfAttention, + "flash_attention_2": BarkSelfFlashAttention2, +} + + +class BarkMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.in_proj = nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=config.bias) + self.out_proj = nn.Linear(4 * config.hidden_size, config.hidden_size, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + self.gelu = nn.GELU() + + def forward(self, hidden_states): + hidden_states = 
self.in_proj(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = self.out_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class BarkBlock(GradientCheckpointingLayer): + def __init__(self, config, is_causal=False, layer_idx=None): + super().__init__() + + if is_causal: + # if causal, the layerNorm bias is optional to stick with Bark choice of leaving optional bias + # in AutoRegressive models (corresponding to the "Text" and the "Coarse" modules) + self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias) + self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias) + else: + self.layernorm_1 = nn.LayerNorm(config.hidden_size) + self.layernorm_2 = nn.LayerNorm(config.hidden_size) + + self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation]( + config, is_causal=is_causal, layer_idx=layer_idx + ) + + self.mlp = BarkMLP(config) + + def forward( + self, + hidden_states, + past_key_values=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + intermediary_hidden_states = self.layernorm_1(hidden_states) + + attn_outputs = self.attn( + intermediary_hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + attn_output = attn_outputs[0] # output_attn: output, present_key_values, (attn_weights) + outputs = attn_outputs[1:] + + intermediary_hidden_states = hidden_states + attn_output + intermediary_hidden_states = intermediary_hidden_states + self.mlp( + self.layernorm_2(intermediary_hidden_states) + ) + + return (intermediary_hidden_states,) + outputs + + +@auto_docstring +class BarkPreTrainedModel(PreTrainedModel): + config: BarkConfig + supports_gradient_checkpointing = False + _supports_flash_attn = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). 
+ """ + + # if has _hf_hook, has been offloaded so the device has to be found in the hook + if not hasattr(self, "_hf_hook"): + return get_parameter_device(self) + for module in self.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + + return get_parameter_device(self) + + +# GPT2-like autoregressive model +class BarkCausalModel(BarkPreTrainedModel, GenerationMixin): + config: BarkSubModelConfig + + def __init__(self, config): + super().__init__(config) + self.config = config + + # initialize as an autoregressive GPT-like model + self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size) + self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size) + + self.drop = nn.Dropout(config.dropout) + + self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)]) + + self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias) + + self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + # NOTE: get_output_embeddings() must return None to prevent accidental weight tying. + # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400 + return None + + def get_input_embeddings(self): + return self.input_embeds_layer + + def set_input_embeddings(self, new_embeddings): + self.input_embeds_layer = new_embeddings + + def prepare_inputs_for_generation( + self, + input_ids, + attention_mask=None, + input_embeds=None, + past_key_values=None, + position_ids=None, + use_cache=None, + cache_position=None, + **kwargs, + ): + # Overwritten -- bark uses `input_embeds` not `inputS_embeds` + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + attention_mask=attention_mask, + inputs_embeds=input_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None) + return model_inputs + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + input_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]: + r""" + input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + Here, due to `Bark` particularities, if `past_key_values` is used, `input_embeds` will be ignored and you + have to use `input_ids`. If `past_key_values` is not used and `use_cache` is set to `True`, `input_embeds` + is used in priority instead of `input_ids`. 
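A minimal sketch of the `input_ids` / `input_embeds` interplay described above, using a deliberately tiny, randomly initialized configuration (the sizes are illustrative and not those of the released checkpoints):

```python
import torch

from transformers.models.bark.configuration_bark import BarkSemanticConfig
from transformers.models.bark.modeling_bark import BarkSemanticModel

# Tiny random-weight sub-model, sized for a quick smoke test only.
config = BarkSemanticConfig(num_layers=2, num_heads=2, hidden_size=64, block_size=64)
model = BarkSemanticModel(config).eval()

input_ids = torch.randint(0, config.input_vocab_size, (1, 8))
with torch.no_grad():
    out_from_ids = model(input_ids=input_ids)
    # Without a cache, precomputed embeddings can be passed instead of ids.
    embeds = model.input_embeds_layer(input_ids)
    out_from_embeds = model(input_embeds=embeds)

print(out_from_ids.logits.shape, out_from_embeds.logits.shape)
```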
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + if labels is not None: + raise NotImplementedError( + "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model." + ) + + # Verify if input_embeds already exists + # then compute embeddings. + if input_ids is not None and input_embeds is not None: + raise ValueError("You cannot specify both input_ids and input_embeds at the same time") + elif input_embeds is not None and past_key_values is None: + # we want to return the input_embeds in priority so that it is in line with a weird hack + # of Bark which concatenate two bits of the input_embeds on the first forward pass of the semantic model + pass + elif input_ids is not None: + input_embeds = self.input_embeds_layer(input_ids) # token embeddings of shape (b, t, n_embd) + elif input_embeds is not None: + pass + else: + raise ValueError("You have to specify either input_ids or input_embeds") + + input_shape = input_embeds.size()[:-1] + batch_size = input_embeds.shape[0] + seq_length = input_shape[-1] + + device = input_ids.device if input_ids is not None else input_embeds.device + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + if position_ids is None: + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, seq_length) + + position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd) + + # Attention mask. 
+ if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + attention_mask = attention_mask.view(batch_size, -1) + # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] + # from_seq_length is 1 to easily broadcast + attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_heads x N x N + # head_mask has shape num_layers x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + + hidden_states = self.drop(input_embeds + position_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + + hidden_states = self.layernorm_final(hidden_states) + + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + logits = self.lm_head(hidden_states) + + if not return_dict: + return tuple( + v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Bark semantic (or text) model. It shares the same architecture as the coarse model. + It is a GPT-2 like autoregressive model with a language modeling head on top. + """ +) +class BarkSemanticModel(BarkCausalModel): + base_model_prefix = "semantic" + config: BarkSemanticConfig + + def generate( + self, + input_ids: torch.Tensor, + semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None, + history_prompt: Optional[dict[str, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.LongTensor: + """ + Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt. + + Args: + input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*): + Input ids, i.e tokenized input sentences. Will be truncated up to + semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as + long as the longest generation among the batch. + semantic_generation_config (`BarkSemanticGenerationConfig`): + Generation config indicating how to generate the semantic tokens. + history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*): + Optional `Bark` speaker prompt. + attention_mask (`Optional[torch.Tensor]`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
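In practice, the sub-model `generate` methods in this file are driven end to end by `BarkModel` through a processor; a short usage sketch mirroring the upstream documentation (downloads the `suno/bark-small` checkpoint):

```python
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

# The processor packs `input_ids`, `attention_mask` and the optional speaker
# prompt (`history_prompt`) consumed by the semantic, coarse and fine sub-models.
inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")

audio_array = model.generate(**inputs)
audio_array = audio_array.cpu().numpy().squeeze()
print(audio_array.shape, model.generation_config.sample_rate)
```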
+ + [What are attention masks?](../glossary#attention-mask) + Returns: + torch.LongTensor: Output semantic tokens. + """ + if semantic_generation_config is None: + raise ValueError("`semantic_generation_config` has to be provided") + + batch_size = input_ids.shape[0] + + max_input_semantic_length = semantic_generation_config.max_input_semantic_length + + input_ids = input_ids + semantic_generation_config.text_encoding_offset + + if attention_mask is not None: + input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token) + + if history_prompt is not None: + semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:] + semantic_history = nn.functional.pad( + semantic_history, + (0, max_input_semantic_length - len(semantic_history)), + value=semantic_generation_config.semantic_pad_token, + mode="constant", + ) + else: + semantic_history = torch.tensor( + [semantic_generation_config.semantic_pad_token] * max_input_semantic_length, dtype=torch.int + ).to(self.device) + + semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0) + + infer_array = torch.tensor( + [[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int + ).to(self.device) + + input_embeds = torch.cat( + [ + self.input_embeds_layer(input_ids[:, :max_input_semantic_length]) + + self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]), + self.input_embeds_layer(infer_array), + ], + dim=1, + ) + + tokens_to_suppress = list( + range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token) + ) + tokens_to_suppress.extend( + list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size)) + ) + + suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device) + + min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p) + early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor( + eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device + ) + + # pass input_ids in order to stay consistent with the transformers generate method even though it is not used + # (except to get the input seq_len - that's why we keep the first 257 tokens) + semantic_output = super().generate( + torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device), + input_embeds=input_embeds, + logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor], + generation_config=semantic_generation_config, + **kwargs, + ) # size: 10048 + + # take the generated semantic tokens + semantic_output = semantic_output[:, max_input_semantic_length + 1 :] + + return semantic_output + + +@auto_docstring( + custom_intro=""" + Bark coarse acoustics model. + It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a + language modeling head on top. + """ +) +class BarkCoarseModel(BarkCausalModel): + base_model_prefix = "coarse_acoustics" + config: BarkCoarseConfig + + def preprocess_histories( + self, + max_coarse_history: int, + semantic_to_coarse_ratio: int, + batch_size: int, + semantic_generation_config: int, + codebook_size: int, + history_prompt: Optional[dict[str, torch.Tensor]] = None, + ): + """ + Preprocess the optional `Bark` speaker prompts before `self.generate`. + + Args: + max_coarse_history (`int`): + Maximum size of coarse tokens used. 
+ semantic_to_coarse_ratio (`int`): + Ratio of semantic to coarse frequency + batch_size (`int`): + Batch size, i.e. the number of samples. + semantic_generation_config (`BarkSemanticGenerationConfig`): + Generation config indicating how to generate the semantic tokens. + codebook_size (`int`): + Codebook channel size, i.e. the size of the output vocabulary per codebook channel. + history_prompt (`Optional[dict[str,torch.Tensor]]`): + Optional `Bark` speaker prompt. + Returns: + `tuple(torch.FloatTensor)`: + - **x_semantic_history** (`torch.FloatTensor`) -- Processed semantic speaker prompt. + - **x_coarse_history** (`torch.FloatTensor`) -- Processed coarse speaker prompt. + """ + if history_prompt is not None: + x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0) + # clone to avoid modifying history_prompt.coarse_prompt + x_coarse_history = history_prompt["coarse_prompt"].clone() + + # offset x_coarse_history + if codebook_size is not None: + for n in range(1, x_coarse_history.shape[0]): + # offset + x_coarse_history[n, :] += codebook_size * n + + # flatten x_coarse_history + x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1) + + x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size + + x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0) + # e.g.: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens + # dedicated to second codebook. + + max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) + # trim histories correctly + n_semantic_hist_provided = min( + [ + max_semantic_history, + x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2, + int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)), + ] + ) + + n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio)) + + x_semantic_history = x_semantic_history[:, -n_semantic_hist_provided:].int() + x_coarse_history = x_coarse_history[:, -n_coarse_hist_provided:].int() + # bit of a hack for time alignment (sounds better) - from Bark original implementation + x_coarse_history = x_coarse_history[:, :-2] + + else: + # shape: (batch_size, 0) + x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device) + x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device) + + return x_semantic_history, x_coarse_history + + def generate( + self, + semantic_output: torch.Tensor, + semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None, + coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None, + codebook_size: int = 1024, + history_prompt: Optional[dict[str, torch.Tensor]] = None, + return_output_lengths: Optional[bool] = None, + **kwargs, + ) -> Union[torch.LongTensor, tuple[torch.LongTensor, torch.LongTensor]]: + """ + Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker + prompt. + + Args: + semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*): + Input text semantic ids, i.e. the output of `BarkSemanticModel.generate`. + semantic_generation_config (`BarkSemanticGenerationConfig`): + Generation config indicating how to generate the semantic tokens. + coarse_generation_config (`BarkCoarseGenerationConfig`): + Generation config indicating how to generate the coarse tokens.
+ codebook_size (`int`, *optional*, defaults to 1024): + Codebook channel size, i.e. the size of the output vocabulary per codebook channel. + history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*): + Optional `Bark` speaker prompt. + return_output_lengths (`bool`, *optional*): + Whether or not to return the output lengths. Useful when batching. + Returns: + By default: + torch.LongTensor: Output coarse acoustics tokens. + If `return_output_lengths=True`: + `Tuple(torch.Tensor, torch.Tensor): The output coarse acoustics tokens, and the length of each sample + of the batch. + """ + + if semantic_generation_config is None: + raise ValueError("`semantic_generation_config` has to be provided") + + if coarse_generation_config is None: + raise ValueError("`coarse_generation_config` has to be provided") + + max_coarse_input_length = coarse_generation_config.max_coarse_input_length + max_coarse_history = coarse_generation_config.max_coarse_history + sliding_window_len = coarse_generation_config.sliding_window_len + + # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e the pad_token + # used in the next model + semantic_output.masked_fill_( + semantic_output == semantic_generation_config.semantic_pad_token, + coarse_generation_config.coarse_semantic_pad_token, + ) + + semantic_to_coarse_ratio = ( + coarse_generation_config.coarse_rate_hz + / semantic_generation_config.semantic_rate_hz + * coarse_generation_config.n_coarse_codebooks + ) + max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) + + output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1) + output_lengths = torch.floor( + output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks + ) + output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int() + + max_generated_len = torch.max(output_lengths).item() + + batch_size = semantic_output.shape[0] + + x_semantic_history, x_coarse = self.preprocess_histories( + history_prompt=history_prompt, + max_coarse_history=max_coarse_history, + semantic_to_coarse_ratio=semantic_to_coarse_ratio, + batch_size=batch_size, + semantic_generation_config=semantic_generation_config, + codebook_size=codebook_size, + ) + base_semantic_idx = x_semantic_history.shape[1] + + semantic_output = torch.hstack([x_semantic_history, semantic_output]) + + n_window_steps = int(np.ceil(max_generated_len / sliding_window_len)) + + total_generated_len = 0 + + len_coarse_history = x_coarse.shape[1] + + for _ in range(n_window_steps): + semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio)) + + # pad from right side + input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :] + input_coarse = input_coarse[:, :max_coarse_input_length] + input_coarse = F.pad( + input_coarse, + (0, max_coarse_input_length - input_coarse.shape[-1]), + "constant", + coarse_generation_config.coarse_semantic_pad_token, + ) + + input_coarse = torch.hstack( + [ + input_coarse, + torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device), + x_coarse[:, -max_coarse_history:], + ] + ) + + alternatingLogitsProcessor = AlternatingCodebooksLogitsProcessor( + input_coarse.shape[1], + semantic_generation_config.semantic_vocab_size, + codebook_size, + ) + + output_coarse = super().generate( + input_coarse, + logits_processor=[alternatingLogitsProcessor], + 
max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len), + generation_config=coarse_generation_config, + **kwargs, + ) + + input_coarse_len = input_coarse.shape[1] + + x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]]) + total_generated_len = x_coarse.shape[1] - len_coarse_history + + del output_coarse + + coarse_output = x_coarse[:, len_coarse_history:] + + if return_output_lengths: + return coarse_output, output_lengths + + return coarse_output + + +@auto_docstring( + custom_intro=""" + Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and + language modeling heads, one for each codebook. + """ +) +class BarkFineModel(BarkPreTrainedModel): + base_model_prefix = "fine_acoustics" + config: BarkFineConfig + main_input_name = "codebook_idx" + + def __init__(self, config): + # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec + super().__init__(config) + self.config = config + + # initialize a modified non causal GPT-like model + # note that for there is one embedding layer and one lm_head for each codebook of Encodec + self.input_embeds_layers = nn.ModuleList( + [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)] + ) + self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size) + + self.drop = nn.Dropout(config.dropout) + + self.layers = nn.ModuleList( + [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)] + ) + + self.layernorm_final = nn.LayerNorm(config.hidden_size) + + self.lm_heads = nn.ModuleList( + [ + nn.Linear(config.hidden_size, config.output_vocab_size, bias=False) + for _ in range(config.n_codes_given, config.n_codes_total) + ] + ) + self.gradient_checkpointing = False + self.n_codes_total = config.n_codes_total + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + # one embedding layers for each codebook + return self.input_embeds_layers + + def set_input_embeddings(self, new_embeddings): + # one embedding layers for each codebook + self.input_embeds_layers = new_embeddings + + def get_output_embeddings(self): + # one lm_head for each codebook + return self.lm_heads + + def set_output_embeddings(self, new_output_embeddings): + # one lm_head for each codebook + self.lm_heads = new_output_embeddings + + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True): + old_embeddings_list = self.get_input_embeddings() + new_embeddings_list = nn.ModuleList( + [ + self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing) + for old_embeddings in old_embeddings_list + ] + ) + self.set_input_embeddings(new_embeddings_list) + new_num_tokens = new_embeddings_list[0].weight.shape[0] + + # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head_list = self.get_output_embeddings() + new_lm_head_list = nn.ModuleList( + [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list] + ) + self.set_output_embeddings(new_lm_head_list) + + return self.get_input_embeddings() + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + """ + Resizes input token embeddings 
matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. + """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Update base model and current model config + self.config.output_vocab_size = model_embeds[0].weight.shape[0] + self.config.vocab_size = model_embeds[0].weight.shape[0] + self.output_vocab_size = model_embeds[0].weight.shape[0] + self.vocab_size = model_embeds[0].weight.shape[0] + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _tie_weights(self): + if getattr(self.config, "tie_word_embeddings", True): + self._tied_weights_keys = [] + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + + for i in range(self.config.n_codes_total - self.config.n_codes_given): + # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight + self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1]) + self._tied_weights_keys.append(f"lm_heads.{i}.weight") + + def tie_weights(self): + """ + Tie the weights between the input embeddings list and the output embeddings list. + + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the + weights instead. 
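+ + Concretely, when `tie_word_embeddings` is left at its default of `True`, each `lm_heads[i]` shares its weight + matrix with `input_embeds_layers[i+1]` (the index offset comes from `n_codes_given`), so the head predicting a + given codebook is tied to that codebook's input embedding.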
+ """ + for module in self.modules(): + if hasattr(module, "_tie_weights"): + module._tie_weights() + + @auto_docstring + def forward( + self, + codebook_idx: int, # an additional idx corresponding to the id of the codebook that will be predicted + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + input_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + codebook_idx (`int`): + Index of the codebook that will be predicted. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + NOT IMPLEMENTED YET. + input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If + `past_key_values` is used, optionally only the last `input_embeds` have to be input (see + `past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into + associated vectors than the model's internal embedding lookup matrix. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + if codebook_idx == 0: + raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model") + + if input_ids is not None and input_embeds is not None: + raise ValueError("You cannot specify both input_ids and input_embeds at the same time") + + if input_ids is None and input_embeds is None: + raise ValueError("You have to specify either input_ids or input_embeds") + + if input_ids is not None: + # the input_embeddings are the sum of the j previous codebooks embeddings before + # the current codebook_idx codebook + + # forward the GPT model itself + input_embeds = [ + input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1) + for i, input_embeds_layer in enumerate(self.input_embeds_layers) + ] # token embeddings of shape (b, t, n_embd) + input_embeds = torch.cat(input_embeds, dim=-1) + input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1) + + input_shape = input_embeds.size()[:-1] + batch_size = input_embeds.shape[0] + seq_length = input_shape[1] + + device = input_ids.device if input_ids is not None else input_embeds.device + + if position_ids is None: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, seq_length) + + position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd) + + # Attention mask. 
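+ # Note: with the flash_attention_2 backend the dense mask is only kept when some position is actually padded + # (a 0 appears in it) and is dropped otherwise; for the other backends the 2D (batch_size, seq_len) mask is + # expanded below to a 4D additive mask that broadcasts over heads and query positions.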
+ if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] + # from_seq_length is 1 to easily broadcast + attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) + + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + + hidden_states = self.drop(input_embeds + position_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + + hidden_states = self.layernorm_final(hidden_states) + hidden_states = hidden_states.view(output_shape) + + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states) + + if not return_dict: + return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None) + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @torch.no_grad() + def generate( + self, + coarse_output: torch.Tensor, + semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None, + coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None, + fine_generation_config: BarkFineGenerationConfig = None, + codebook_size: int = 1024, + history_prompt: Optional[dict[str, torch.Tensor]] = None, + **kwargs, + ) -> torch.LongTensor: + """ + Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker + prompt. + + Args: + coarse_output (`torch.Tensor` of shape (batch_size, seq_len)): + Input coarse acoustics ids, i.e the output of `BarkCoarseModel.generate`. + semantic_generation_config (`BarkSemanticGenerationConfig`): + Generation config indicating how to generate the semantic tokens. + coarse_generation_config (`BarkCoarseGenerationConfig`): + Generation config indicating how to generate the coarse tokens. + fine_generation_config (`BarkFineGenerationConfig`): + Generation config indicating how to generate the fine tokens. + codebook_size (`int`, *optional*, defaults to 1024): + Codebook channel size, i.e. the size of the output vocabulary per codebook channel. + history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*): + Optional `Bark` speaker prompt. + Returns: + torch.LongTensor: Output fine acoustics tokens. 
+ """ + if semantic_generation_config is None: + raise ValueError("`semantic_generation_config` has to be provided") + + if coarse_generation_config is None: + raise ValueError("`coarse_generation_config` has to be provided") + + if fine_generation_config is None: + raise ValueError("`fine_generation_config` has to be provided") + + # since we don't really use GenerationConfig through the fine model (autoencoder) + # and since only temperature is used from the classic GenerationConfig parameters + # manually impose the kwargs priority over the generation config + temperature = kwargs.get("temperature", fine_generation_config.temperature) + + max_fine_history_length = fine_generation_config.max_fine_history_length + max_fine_input_length = fine_generation_config.max_fine_input_length + + # shape: (batch, n_coarse_codebooks * seq_len) + # new_shape: (batch, seq_len, n_coarse_codebooks) + coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks) + + # brings ids into the range [0, codebook_size -1] + coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size) + batch_size = coarse_output.shape[0] + + if history_prompt is not None: + x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0) + # transpose to get to shape (seq_len, n_fine_codebooks) + else: + x_fine_history = None + + n_coarse = coarse_generation_config.n_coarse_codebooks + + # pad the last 6th codebooks + fine_input = F.pad( + coarse_output, + (0, fine_generation_config.n_fine_codebooks - n_coarse), + "constant", + codebook_size, + ) + + # prepend history if available (max max_fine_history_length) + if x_fine_history is not None: + fine_input = torch.cat([x_fine_history[:, -max_fine_history_length:, :], fine_input], dim=1) + + # len of the fine_history that has been added to fine_input + n_history = x_fine_history[:, -max_fine_history_length:, :].shape[1] + else: + n_history = 0 + + n_remove_from_end = 0 + # need to pad if too short (since non-causal model) + if fine_input.shape[1] < max_fine_input_length: + n_remove_from_end = max_fine_input_length - fine_input.shape[1] + fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size) + + # we can be lazy about fractional loop and just keep overwriting codebooks. + # seems that coarse_output.shape[1] - (max_fine_input_length - n_history) is equal to minus n_remove_from_end + # So if we needed to pad because too short, n_loops is always 1 (because n_remove_from_end > 0) + # If not, we loop over at least twice. 
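+ # Illustration (assuming the default fine generation config, max_fine_input_length=1024 and + # max_fine_history_length=512, and no speaker prompt, i.e. n_history=0): 1300 coarse frames give + # n_loops = max(0, ceil((1300 - 1024) / 512)) + 1 = 2 windows, while 800 frames (shorter than 1024, hence padded) + # give n_loops = max(0, ceil((800 - 1024) / 512)) + 1 = 1 window.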
+ + n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length + n_loops = int(np.ceil(n_loops)) + n_loops = max(0, n_loops) + 1 + + for n_outer in range(n_loops): + start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length]) + + start_fill_idx = min( + [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length] + ) + rel_start_fill_idx = start_fill_idx - start_idx + input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :] + for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks): + logits = self.forward(n_inner, input_buffer).logits + if temperature is None or temperature == 1.0: + relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size] + codebook_preds = torch.argmax(relevant_logits, -1) + else: + relevant_logits = logits[:, :, :codebook_size] / temperature + # apply softmax + probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length] + # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size) + probs = probs.reshape((-1, codebook_size)) + # multinomial then reshape : (batch_size*seq_len)-> (batch_size,seq_len) + codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1) + codebook_preds = codebook_preds.to(torch.int32) + input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds + del logits, codebook_preds + + # transfer into fine_input + for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks): + fine_input[ + :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner + ] = input_buffer[:, rel_start_fill_idx:, n_inner] + del input_buffer + + fine_input = fine_input.transpose(1, 2)[:, :, n_history:] + if n_remove_from_end > 0: + fine_input = fine_input[:, :, :-n_remove_from_end] + + if fine_input.shape[-1] != coarse_output.shape[-2]: + raise ValueError("input and output should have the same seq_len") + + return fine_input + + +@auto_docstring( + custom_intro=""" + The full Bark model, a text-to-speech model composed of 4 sub-models: + - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal auto-regressive transformer model that + takes + as input tokenized text, and predicts semantic text tokens that capture the meaning of the text. + - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model), also a causal autoregressive transformer, + that takes into input the results of the last model. It aims at regressing the first two audio codebooks necessary + to `encodec`. + - [`BarkFineModel`] (the 'fine acoustics' model), this time a non-causal autoencoder transformer, which iteratively + predicts the last codebooks based on the sum of the previous codebooks embeddings. + - having predicted all the codebook channels from the [`EncodecModel`], Bark uses it to decode the output audio + array. + + It should be noted that each of the first three modules can support conditional speaker embeddings to condition the + output sound according to specific predefined voice. 
+ """ +) +class BarkModel(BarkPreTrainedModel): + config: BarkConfig + + def __init__(self, config): + super().__init__(config) + + self.semantic = BarkSemanticModel(config.semantic_config) + self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config) + self.fine_acoustics = BarkFineModel(config.fine_acoustics_config) + + self.codec_model = AutoModel.from_config(config.codec_config) + + self.config = config + + @classmethod + def can_generate(cls) -> bool: + # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from + # `GenerationMixin` (it has a non-standard generation method), but one of the internal models do + # (`BarkSemanticModel`). This means that the base `can_generate()` will return `False`, but we need to + # override it so as to do `GenerationConfig` handling in multiple parts of the codebase. + return True + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + # for bark_model, device must be verified on its sub-models + # if has _hf_hook, has been offloaded so the device has to be found in the hook + if not hasattr(self.semantic, "_hf_hook"): + return get_parameter_device(self) + for module in self.semantic.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + + def enable_cpu_offload( + self, + accelerator_id: Optional[int] = 0, + **kwargs, + ): + r""" + Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This + method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in accelerator until the next sub-model runs. + + Args: + accelerator_id (`int`, *optional*, defaults to 0): + accelerator id on which the sub-models will be loaded and offloaded. This argument is deprecated. + kwargs (`dict`, *optional*): + additional keyword arguments: + `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded. + """ + if is_accelerate_available(): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate`.") + + gpu_id = kwargs.get("gpu_id", 0) + + if gpu_id != 0: + warnings.warn( + "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. 
Please use `accelerator_id` instead.", + FutureWarning, + ) + accelerator_id = gpu_id + + device_type = "cuda" + if is_torch_accelerator_available(): + device_type = torch.accelerator.current_accelerator().type + device = torch.device(f"{device_type}:{accelerator_id}") + + torch_accelerator_module = getattr(torch, device_type) + if self.device.type != "cpu": + self.to("cpu") + torch_accelerator_module.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + # this layer is used outside the first forward pass of semantic so need to be loaded before semantic + self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device) + + hook = None + for cpu_offloaded_model in [ + self.semantic, + self.coarse_acoustics, + self.fine_acoustics, + ]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + self.fine_acoustics_hook = hook + + _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.codec_model_hook = hook + + def codec_decode(self, fine_output, output_lengths=None): + """Turn quantized audio codes into audio array using encodec.""" + + fine_output = fine_output.transpose(0, 1) + emb = self.codec_model.quantizer.decode(fine_output) + + if output_lengths is not None: + # encodec uses LSTMs which behaves differently with appended padding + # decoding with encodec takes around 0.1% of the total generation time + # to keep generation quality, we break batching + out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)] + audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out] + else: + out = self.codec_model.decoder(emb) + audio_arr = out.squeeze(1) # squeeze the codebook dimension + + return audio_arr + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + history_prompt: Optional[dict[str, torch.Tensor]] = None, + return_output_lengths: Optional[bool] = None, + **kwargs, + ) -> torch.LongTensor: + """ + Generates audio from an input prompt and an additional optional `Bark` speaker prompt. + + Args: + input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*): + Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the + longest generation among the batch. + history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*): + Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch. + kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types: + + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model. + - With a *semantic_*, *coarse_*, *fine_* prefix, they will be input for the `generate` method of the + semantic, coarse and fine respectively. It has the priority over the keywords without a prefix. + + This means you can, for example, specify a generation strategy for all sub-models except one. + return_output_lengths (`bool`, *optional*): + Whether or not to return the waveform lengths. Useful when batching. + Returns: + By default: + - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform. + When `return_output_lengths=True`: + Returns a tuple made of: + - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform. 
+ - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch + Example: + + ```python + >>> from transformers import AutoProcessor, BarkModel + + >>> processor = AutoProcessor.from_pretrained("suno/bark-small") + >>> model = BarkModel.from_pretrained("suno/bark-small") + + >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)` + >>> voice_preset = "v2/en_speaker_6" + + >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset) + + >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100) + >>> audio_array = audio_array.cpu().numpy().squeeze() + ``` + """ + # TODO (joao):workaround until nested generation config is compatible with PreTrained Model + # todo: dict + semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config) + coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config) + fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config) + + kwargs_semantic = { + # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel + "attention_mask": kwargs.pop("attention_mask", None), + "min_eos_p": kwargs.pop("min_eos_p", None), + } + kwargs_coarse = {} + kwargs_fine = {} + for key, value in kwargs.items(): + if key.startswith("semantic_"): + key = key[len("semantic_") :] + kwargs_semantic[key] = value + elif key.startswith("coarse_"): + key = key[len("coarse_") :] + kwargs_coarse[key] = value + elif key.startswith("fine_"): + key = key[len("fine_") :] + kwargs_fine[key] = value + else: + # If the key is already in a specific config, then it's been set with a + # submodules specific value and we don't override + if key not in kwargs_semantic: + kwargs_semantic[key] = value + if key not in kwargs_coarse: + kwargs_coarse[key] = value + if key not in kwargs_fine: + kwargs_fine[key] = value + + # 1. Generate from the semantic model + if "generation_config" in kwargs_semantic: + kwargs_semantic.pop("generation_config") + semantic_output = self.semantic.generate( + input_ids, + history_prompt=history_prompt, + semantic_generation_config=semantic_generation_config, + **kwargs_semantic, + ) + + # 2. Generate from the coarse model + if "generation_config" in kwargs_coarse: + kwargs_coarse.pop("generation_config") + coarse_output = self.coarse_acoustics.generate( + semantic_output, + history_prompt=history_prompt, + semantic_generation_config=semantic_generation_config, + coarse_generation_config=coarse_generation_config, + codebook_size=self.generation_config.codebook_size, + return_output_lengths=return_output_lengths, + **kwargs_coarse, + ) + + output_lengths = None + if return_output_lengths: + coarse_output, output_lengths = coarse_output + # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len) + output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks + + # 3. 
"generate" from the fine model + if "generation_config" in kwargs_fine: + kwargs_fine.pop("generation_config") + output = self.fine_acoustics.generate( + coarse_output, + history_prompt=history_prompt, + semantic_generation_config=semantic_generation_config, + coarse_generation_config=coarse_generation_config, + fine_generation_config=fine_generation_config, + codebook_size=self.generation_config.codebook_size, + **kwargs_fine, + ) + + if getattr(self, "fine_acoustics_hook", None) is not None: + # Manually offload fine_acoustics to CPU + # and load codec_model to GPU + # since bark doesn't use codec_model forward pass + self.fine_acoustics_hook.offload() + self.codec_model = self.codec_model.to(self.device) + + # 4. Decode the output and generate audio array + audio = self.codec_decode(output, output_lengths) + + if getattr(self, "codec_model_hook", None) is not None: + # Offload codec_model to CPU + self.codec_model_hook.offload() + + if return_output_lengths: + output_lengths = [len(sample) for sample in audio] + audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0) + return audio, output_lengths + + return audio + + def tie_weights(self): + """ + Tie the weights between the input embeddings list and the output embeddings list. + + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the + weights instead. + """ + for module in self.modules(): + if hasattr(module, "_tie_weights"): + module._tie_weights() + + +__all__ = [ + "BarkFineModel", + "BarkSemanticModel", + "BarkCoarseModel", + "BarkModel", + "BarkPreTrainedModel", + "BarkCausalModel", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py new file mode 100644 index 0000000000000000000000000000000000000000..155f15cced201b57835fc3037c4f1dfb4a110ffa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Bark +""" + +import json +import os +from typing import Optional + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding +from ...utils import logging +from ...utils.hub import cached_file +from ..auto import AutoTokenizer + + +logger = logging.get_logger(__name__) + + +class BarkProcessor(ProcessorMixin): + r""" + Constructs a Bark processor which wraps a text tokenizer and optional Bark voice presets into a single processor. + + Args: + tokenizer ([`PreTrainedTokenizer`]): + An instance of [`PreTrainedTokenizer`]. + speaker_embeddings (`dict[dict[str]]`, *optional*): + Optional nested speaker embeddings dictionary. 
The first level contains voice preset names (e.g + `"en_speaker_4"`). The second level contains `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"` + embeddings. The values correspond to the path of the corresponding `np.ndarray`. See + [here](https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c) for + a list of `voice_preset_names`. + + """ + + tokenizer_class = "AutoTokenizer" + attributes = ["tokenizer"] + + preset_shape = { + "semantic_prompt": 1, # 1D array of shape (X,) + "coarse_prompt": 2, # 2D array of shape (2,X) + "fine_prompt": 2, # 2D array of shape (8,X) + } + + def __init__(self, tokenizer, speaker_embeddings=None): + super().__init__(tokenizer) + + self.speaker_embeddings = speaker_embeddings + + @classmethod + def from_pretrained( + cls, pretrained_processor_name_or_path, speaker_embeddings_dict_path="speaker_embeddings_path.json", **kwargs + ): + r""" + Instantiate a Bark processor associated with a pretrained model. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained [`BarkProcessor`] hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a processor saved using the [`~BarkProcessor.save_pretrained`] + method, e.g., `./my_model_directory/`. + speaker_embeddings_dict_path (`str`, *optional*, defaults to `"speaker_embeddings_path.json"`): + The name of the `.json` file containing the speaker_embeddings dictionary located in + `pretrained_model_name_or_path`. If `None`, no speaker_embeddings is loaded. + **kwargs + Additional keyword arguments passed along to both + [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`]. + """ + + if speaker_embeddings_dict_path is not None: + speaker_embeddings_path = cached_file( + pretrained_processor_name_or_path, + speaker_embeddings_dict_path, + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + if speaker_embeddings_path is None: + logger.warning( + f"""`{os.path.join(pretrained_processor_name_or_path, speaker_embeddings_dict_path)}` does not exists + , no preloaded speaker embeddings will be used - Make sure to provide a correct path to the json + dictionary if wanted, otherwise set `speaker_embeddings_dict_path=None`.""" + ) + speaker_embeddings = None + else: + with open(speaker_embeddings_path) as speaker_embeddings_json: + speaker_embeddings = json.load(speaker_embeddings_json) + else: + speaker_embeddings = None + + if speaker_embeddings is not None: + if "repo_or_path" in speaker_embeddings: + speaker_embeddings["repo_or_path"] = pretrained_processor_name_or_path + tokenizer = AutoTokenizer.from_pretrained(pretrained_processor_name_or_path, **kwargs) + + return cls(tokenizer=tokenizer, speaker_embeddings=speaker_embeddings) + + def save_pretrained( + self, + save_directory, + speaker_embeddings_dict_path="speaker_embeddings_path.json", + speaker_embeddings_directory="speaker_embeddings", + push_to_hub: bool = False, + **kwargs, + ): + """ + Saves the attributes of this processor (tokenizer...) 
in the specified directory so that it can be reloaded + using the [`~BarkProcessor.from_pretrained`] method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the tokenizer files and the speaker embeddings will be saved (directory will be created + if it does not exist). + speaker_embeddings_dict_path (`str`, *optional*, defaults to `"speaker_embeddings_path.json"`): + The name of the `.json` file that will contains the speaker_embeddings nested path dictionary, if it + exists, and that will be located in `pretrained_model_name_or_path/speaker_embeddings_directory`. + speaker_embeddings_directory (`str`, *optional*, defaults to `"speaker_embeddings/"`): + The name of the folder in which the speaker_embeddings arrays will be saved. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs: + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if self.speaker_embeddings is not None: + os.makedirs(os.path.join(save_directory, speaker_embeddings_directory, "v2"), exist_ok=True) + + embeddings_dict = {} + + embeddings_dict["repo_or_path"] = save_directory + + for prompt_key in self.available_voice_presets: + voice_preset = self._load_voice_preset(prompt_key) + + tmp_dict = {} + for key in self.speaker_embeddings[prompt_key]: + np.save( + os.path.join( + embeddings_dict["repo_or_path"], speaker_embeddings_directory, f"{prompt_key}_{key}" + ), + voice_preset[key], + allow_pickle=False, + ) + tmp_dict[key] = os.path.join(speaker_embeddings_directory, f"{prompt_key}_{key}.npy") + + embeddings_dict[prompt_key] = tmp_dict + + with open(os.path.join(save_directory, speaker_embeddings_dict_path), "w") as fp: + json.dump(embeddings_dict, fp) + + super().save_pretrained(save_directory, push_to_hub, **kwargs) + + def _load_voice_preset(self, voice_preset: Optional[str] = None, **kwargs): + voice_preset_paths = self.speaker_embeddings[voice_preset] + + voice_preset_dict = {} + for key in ["semantic_prompt", "coarse_prompt", "fine_prompt"]: + if key not in voice_preset_paths: + raise ValueError( + f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]." 
+ ) + + path = cached_file( + self.speaker_embeddings.get("repo_or_path", "/"), + voice_preset_paths[key], + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + if path is None: + raise ValueError( + f"""`{os.path.join(self.speaker_embeddings.get("repo_or_path", "/"), voice_preset_paths[key])}` does not exists + , no preloaded voice preset will be used - Make sure to provide correct paths to the {voice_preset} + embeddings.""" + ) + + voice_preset_dict[key] = np.load(path) + + return voice_preset_dict + + def _validate_voice_preset_dict(self, voice_preset: Optional[dict] = None): + for key in ["semantic_prompt", "coarse_prompt", "fine_prompt"]: + if key not in voice_preset: + raise ValueError(f"Voice preset unrecognized, missing {key} as a key.") + + if not isinstance(voice_preset[key], np.ndarray): + raise TypeError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.") + + if len(voice_preset[key].shape) != self.preset_shape[key]: + raise ValueError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.") + + @property + def available_voice_presets(self) -> list: + """ + Returns a list of available voice presets. + + Returns: + `list[str]`: A list of voice preset names. + """ + if self.speaker_embeddings is None: + return [] + + voice_presets = list(self.speaker_embeddings.keys()) + if "repo_or_path" in voice_presets: + voice_presets.remove("repo_or_path") + return voice_presets + + def _verify_speaker_embeddings(self, remove_unavailable: bool = True): + # check which actually downloaded properly / are available + unavailable_keys = [] + if self.speaker_embeddings is not None: + for voice_preset in self.available_voice_presets: + try: + voice_preset_dict = self._load_voice_preset(voice_preset) + except ValueError: + # error from `_load_voice_preset` of path not existing + unavailable_keys.append(voice_preset) + continue + self._validate_voice_preset_dict(voice_preset_dict) + + if unavailable_keys: + logger.warning( + f"The following {len(unavailable_keys)} speaker embeddings are not available: {unavailable_keys} " + "If you would like to use them, please check the paths or try downloading them again." + ) + + if remove_unavailable: + for voice_preset in unavailable_keys: + del self.speaker_embeddings[voice_preset] + + def __call__( + self, + text=None, + voice_preset=None, + return_tensors="pt", + max_length=256, + add_special_tokens=False, + return_attention_mask=True, + return_token_type_ids=False, + **kwargs, + ) -> BatchEncoding: + """ + Main method to prepare for the model one or several sequences(s). This method forwards the `text` and `kwargs` + arguments to the AutoTokenizer's [`~AutoTokenizer.__call__`] to encode the text. The method also proposes a + voice preset which is a dictionary of arrays that conditions `Bark`'s output. `kwargs` arguments are forwarded + to the tokenizer and to `cached_file` method if `voice_preset` is a valid filename. + + Args: + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. 
Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + voice_preset (`str`, `dict[np.ndarray]`): + The voice preset, i.e the speaker embeddings. It can either be a valid voice_preset name, e.g + `"en_speaker_1"`, or directly a dictionary of `np.ndarray` embeddings for each submodel of `Bark`. Or + it can be a valid file name of a local `.npz` single voice preset containing the keys + `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] object containing the output of the `tokenizer`. + If a voice preset is provided, the returned object will include a `"history_prompt"` key + containing a [`BatchFeature`], i.e the voice preset with the right tensors type. + """ + if voice_preset is not None and not isinstance(voice_preset, dict): + if ( + isinstance(voice_preset, str) + and self.speaker_embeddings is not None + and voice_preset in self.speaker_embeddings + ): + voice_preset = self._load_voice_preset(voice_preset) + + else: + if isinstance(voice_preset, str) and not voice_preset.endswith(".npz"): + voice_preset = voice_preset + ".npz" + + voice_preset = np.load(voice_preset) + + if voice_preset is not None: + self._validate_voice_preset_dict(voice_preset, **kwargs) + voice_preset = BatchFeature(data=voice_preset, tensor_type=return_tensors) + + encoded_text = self.tokenizer( + text, + return_tensors=return_tensors, + padding="max_length", + max_length=max_length, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + add_special_tokens=add_special_tokens, + **kwargs, + ) + + if voice_preset is not None: + encoded_text["history_prompt"] = voice_preset + + return encoded_text + + +__all__ = ["BarkProcessor"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef22794dde26e6275ba0ae850f6042ff6a451fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
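+# Note on lazy loading: under TYPE_CHECKING the submodules are imported eagerly so static type checkers can see +# the exported symbols; at runtime the package is replaced by a _LazyModule, so a statement such as +# `from transformers.models.bert import BertConfig` only imports `configuration_bert` the first time it is needed.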
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bert import * + from .modeling_bert import * + from .modeling_flax_bert import * + from .modeling_tf_bert import * + from .tokenization_bert import * + from .tokenization_bert_fast import * + from .tokenization_bert_tf import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e51d3295ef5698ea6b060002a9d18a10587986 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BERT model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to + instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BERT + [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. 
+ + Examples: + + ```python + >>> from transformers import BertConfig, BertModel + + >>> # Initializing a BERT google-bert/bert-base-uncased style configuration + >>> configuration = BertConfig() + + >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration + >>> model = BertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class BertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) + + +__all__ = ["BertConfig", "BertOnnxConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..b9238d8bb0713e72fd4d44c70e2404d475cd4048 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py @@ -0,0 +1,1801 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch BERT model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + 
logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = 
nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.layer_idx = layer_idx + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if past_key_values is not None: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, attention_probs + + +class BertSdpaSelfAttention(BertSelfAttention): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) + self.dropout_prob = config.attention_probs_dropout_prob + + # Adapted from BertSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. 
+ logger.warning_once( + "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + past_key_values, + output_attentions, + cache_position, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) + + is_updated = False + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values + else: + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + return attn_output, None + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +BERT_SELF_ATTENTION_CLASSES = { + "eager": BertSelfAttention, + "sdpa": BertSdpaSelfAttention, +} + + +class BertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__() + self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, + ) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = 
self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(GradientCheckpointingLayer): + def __init__(self, config, layer_idx=None): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config, layer_idx=layer_idx) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + past_key_values=past_key_values, + cache_position=cache_position, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + self.config = config + self.layer = 
nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and self.config.is_decoder and past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + + if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +@auto_docstring +class BertPreTrainedModel(PreTrainedModel): + config: BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, BertLMPredictionHead): + module.bias.data.zero_() + + 
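The pooler and head classes above are thin layers over the encoder output: `BertPooler` applies a dense + tanh transform to the first token's hidden state, and `BertLMPredictionHead` re-projects transformed hidden states to vocabulary logits with a bias kept tied to `decoder.bias` so it is resized together with the embeddings. A minimal sketch of the pooling step in plain PyTorch (the batch and sequence sizes are illustrative assumptions, not values taken from this file):

```python
import torch
from torch import nn

hidden_size = 768  # matches the BertConfig default above

# sequence_output as produced by BertEncoder: (batch, seq_len, hidden_size)
sequence_output = torch.randn(2, 16, hidden_size)

# BertPooler: take the hidden state of the first ([CLS]) token, then dense + tanh
dense = nn.Linear(hidden_size, hidden_size)
pooled_output = torch.tanh(dense(sequence_output[:, 0]))

print(pooled_output.shape)  # torch.Size([2, 768])
```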
+@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`BertForPreTraining`]. + """ +) +class BertForPreTrainingOutput(ModelOutput): + r""" + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + seq_relationship_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@auto_docstring( + custom_intro=""" + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ +) +class BertModel(BertPreTrainedModel): + _no_split_modules = ["BertEmbeddings", "BertLayer"] + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. 
+ """ +) +class BertForPreTraining(BertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Bert Model with a `language modeling` head on top for CLM fine-tuning. 
+ """ +) +class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **loss_kwargs, + ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class BertForMaskedLM(BertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + 
"bi-directional self-attention." + ) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def can_generate(cls) -> bool: + """ + Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. 
+ """ + return False + + +@auto_docstring( + custom_intro=""" + Bert Model with a `next sentence prediction (classification)` head on top. + """ +) +class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """ +) +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + 
+ @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class BertForTokenClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@auto_docstring +class BertForQuestionAnswering(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, 
+ hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "BertForMaskedLM", + "BertForMultipleChoice", + "BertForNextSentencePrediction", + "BertForPreTraining", + "BertForQuestionAnswering", + "BertForSequenceClassification", + "BertForTokenClassification", + "BertLayer", + "BertLMHeadModel", + "BertModel", + "BertPreTrainedModel", + "load_tf_weights_in_bert", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..37828642eb4eb025973952ce0750a66e7e7693ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py @@ -0,0 +1,1727 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxNextSentencePredictorOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased" +_CONFIG_FOR_DOC = "BertConfig" + +remat = nn_partitioning.remat + + +@flax.struct.dataclass +class FlaxBertForPreTrainingOutput(ModelOutput): + """ + Output type of [`BertForPreTraining`]. + + Args: + prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). 
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    prediction_logits: jnp.ndarray = None
+    seq_relationship_logits: jnp.ndarray = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+BERT_START_DOCSTRING = r"""
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
+
+    This model is also a
+    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
+    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
+    behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given `dtype`.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`numpy.ndarray` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ +""" + + +class FlaxBertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxBertSelfAttention(nn.Module): + config: BertConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. 
+ is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. 
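+        # When the cache is in use, `_concatenate_to_cache` also returns an attention mask that
+        # spans the full `max_length` key/value buffers, so the attention bias computed below
+        # lines up with the padded cache rather than with just the current query positions.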
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxBertSelfOutput(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FlaxBertAttention(nn.Module): + config: BertConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +class FlaxBertIntermediate(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + 
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxBertOutput(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +class FlaxBertLayer(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) + self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) + self.output = FlaxBertOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +class FlaxBertLayerCollection(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: 
Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class FlaxBertEncoder(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxBertLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxBertPooler(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +class FlaxBertPredictionHeadTransform(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = 
self.activation(hidden_states) + return self.LayerNorm(hidden_states) + + +class FlaxBertLMPredictionHead(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.transform(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +class FlaxBertOnlyMLMHead(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) + return hidden_states + + +class FlaxBertOnlyNSPHead(nn.Module): + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.seq_relationship = nn.Dense(2, dtype=self.dtype) + + def __call__(self, pooled_output): + return self.seq_relationship(pooled_output) + + +class FlaxBertPreTrainingHeads(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) + self.seq_relationship = nn.Dense(2, dtype=self.dtype) + + def __call__(self, hidden_states, pooled_output, shared_embedding=None): + prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class FlaxBertPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + base_model_prefix = "bert" + module_class: nn.Module = None + + def __init__( + self, + config: BertConfig, + input_shape: tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class( + config=config, + dtype=dtype, + gradient_checkpointing=gradient_checkpointing, + **kwargs, + ) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. 
+ """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: Optional[dict] = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: Optional[dict] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxBertAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +class FlaxBertModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxBertEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: 
+ return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class FlaxBertModel(FlaxBertPreTrainedModel): + module_class = FlaxBertModule + + +append_call_sample_docstring(FlaxBertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC) + + +class FlaxBertForPreTrainingModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + hidden_states = outputs[0] + pooled_output = outputs[1] + + prediction_scores, seq_relationship_score = self.cls( + hidden_states, pooled_output, shared_embedding=shared_embedding + ) + + if not return_dict: + return (prediction_scores, seq_relationship_score) + outputs[2:] + + return FlaxBertForPreTrainingOutput( + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. 
+ """, + BERT_START_DOCSTRING, +) +class FlaxBertForPreTraining(FlaxBertPreTrainedModel): + module_class = FlaxBertForPreTrainingModule + + +FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxBertForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` +""" + +overwrite_call_docstring( + FlaxBertForPreTraining, + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + +class FlaxBertForMaskedLMModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) +class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): + module_class = FlaxBertForMaskedLMModule + + +append_call_sample_docstring(FlaxBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) + + +class FlaxBertForNextSentencePredictionModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + seq_relationship_scores = self.cls(pooled_output) + + if not return_dict: + return (seq_relationship_scores,) + outputs[2:] + + return FlaxNextSentencePredictorOutput( + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", + BERT_START_DOCSTRING, +) +class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): + module_class = FlaxBertForNextSentencePredictionModule + + +FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax") + + >>> outputs = model(**encoding) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` +""" + + +overwrite_call_docstring( + FlaxBertForNextSentencePrediction, + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING, +) +append_replace_return_docstrings( + FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC +) + + +class FlaxBertForSequenceClassificationModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + BERT_START_DOCSTRING, +) +class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): + module_class = FlaxBertForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxBertForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxBertForMultipleChoiceModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + BERT_START_DOCSTRING, +) +class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel): + module_class = FlaxBertForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxBertForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC +) + + +class FlaxBertForTokenClassificationModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + BERT_START_DOCSTRING, +) +class FlaxBertForTokenClassification(FlaxBertPreTrainedModel): + module_class = FlaxBertForTokenClassificationModule + + +append_call_sample_docstring( + FlaxBertForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC +) + + +class FlaxBertForQuestionAnsweringModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + BERT_START_DOCSTRING, +) +class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): + module_class = FlaxBertForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxBertForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxBertForCausalLMModule(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.bert = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + BERT_START_DOCSTRING, +) +class FlaxBertForCausalLM(FlaxBertPreTrainedModel): + module_class = FlaxBertForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. 
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxBertForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxBertForCausalLM", + "FlaxBertForMaskedLM", + "FlaxBertForMultipleChoice", + "FlaxBertForNextSentencePrediction", + "FlaxBertForPreTraining", + "FlaxBertForQuestionAnswering", + "FlaxBertForSequenceClassification", + "FlaxBertForTokenClassification", + "FlaxBertModel", + "FlaxBertPreTrainedModel", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_tf_bert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_tf_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca82f9f18200fafc165d851da0f24591b259aa8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_tf_bert.py @@ -0,0 +1,2125 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
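The `prepare_inputs_for_generation` logic in the Flax causal-LM class above is worth unpacking: the attention mask is allocated once at `max_length` so the compiled shapes stay fixed across decoding steps, the real prompt mask is copied into its left edge with `lax.dynamic_update_slice`, and `position_ids` come from the cumulative sum of the prompt mask. A minimal sketch of that computation with made-up shapes (no checkpoint or model involved):

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical sizes, not tied to any real checkpoint.
batch_size, max_length = 2, 8

# Prompt mask: the second sequence ends with one padding token.
attention_mask = jnp.array([[1, 1, 1], [1, 1, 0]], dtype="i4")

# Allocate the mask once at max_length so traced shapes stay constant while decoding.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))

# Positions count real tokens only: cumulative sum of the mask, shifted to start at 0.
position_ids = attention_mask.cumsum(axis=-1) - 1

print(extended_attention_mask)  # [[1 1 1 1 1 1 1 1] [1 1 0 1 1 1 1 1]]
print(position_ids)             # [[0 1 2] [0 1 1]]
```

The positions past the prompt are left as ones because, as the source comments note, the decoder's causal mask already prevents earlier steps from attending to them.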
+"""TF 2.0 BERT model.""" + +from __future__ import annotations + +import math +import warnings +from dataclasses import dataclass + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFNextSentencePredictorOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFNextSentencePredictionLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased" +_CONFIG_FOR_DOC = "BertConfig" + +# TokenClassification docstring +_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" +_TOKEN_CLASS_EXPECTED_OUTPUT = ( + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) +_TOKEN_CLASS_EXPECTED_LOSS = 0.01 + +# QuestionAnswering docstring +_CHECKPOINT_FOR_QA = "ydshieh/bert-base-cased-squad2" +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/bert-base-uncased-yelp-polarity" +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" +_SEQ_CLASS_EXPECTED_LOSS = 0.01 + + +class TFBertPreTrainingLoss: + """ + Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining + NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss + computation. 
+ """ + + def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0]) + # make sure only labels that are not equal to -100 + # are taken into account for the loss computation + lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype) + masked_lm_losses = unmasked_lm_losses * lm_loss_mask + reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask) + + # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway + unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1]) + ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype) + masked_ns_loss = unmasked_ns_loss * ns_loss_mask + + reduced_masked_ns_loss = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask) + + return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,)) + + +class TFBertEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + def call( + self, + input_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + past_key_values_length=0, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. 
+ """ + if input_ids is None and inputs_embeds is None: + raise ValueError("Need to provide either `input_ids` or `input_embeds`.") + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +class TFBertSelfAttention(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFBertModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
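The score computation above is standard scaled dot-product attention: query-key dot products divided by the square root of the head size, the precomputed additive attention mask added on top, and a numerically stabilised softmax turning scores into probabilities. A self-contained sketch of just that arithmetic, with tiny made-up shapes and plain `tf.nn.softmax` standing in for the library's `stable_softmax` wrapper:

```python
import math

import tensorflow as tf

# Tiny made-up shapes: batch=1, heads=1, seq_len=3, head_size=4.
q = tf.random.normal((1, 1, 3, 4))
k = tf.random.normal((1, 1, 3, 4))
v = tf.random.normal((1, 1, 3, 4))

# Additive mask in the form the model precomputes: 0.0 where attention is allowed,
# -10000.0 on the padded key position.
mask = tf.constant([[[[0.0, 0.0, -10000.0]]]])

scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(4)  # (1, 1, 3, 3)
scores = scores + mask                                     # broadcasts over query positions
probs = tf.nn.softmax(scores, axis=-1)                     # rows sum to 1; last key is ~ignored
context = tf.matmul(probs, v)                              # (1, 1, 3, 4)
```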
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +class TFBertSelfOutput(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFBertAttention(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFBertSelfAttention(config, name="self") + self.dense_output = TFBertSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, 
input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFBertIntermediate(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFBertOutput(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFBertLayer(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFBertAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFBertAttention(config, name="crossattention") + self.intermediate = TFBertIntermediate(config, name="intermediate") + self.bert_output = TFBertOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = 
self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +class TFBertEncoder(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: tuple[tuple[tf.Tensor]] | None, + use_cache: bool | None, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> TFBaseModelOutputWithPastAndCrossAttentions | tuple[tf.Tensor]: + all_hidden_states = () if output_hidden_states 
else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFBertPooler(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFBertPredictionHeadTransform(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + + if isinstance(config.hidden_act, str): + self.transform_act_fn = get_tf_activation(config.hidden_act) + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(inputs=hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFBertLMPredictionHead(keras.layers.Layer): + def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.hidden_size = config.hidden_size + + self.transform = TFBertPredictionHeadTransform(config, name="transform") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
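The comment above describes weight tying: `TFBertLMPredictionHead` keeps a reference to the input embedding layer and, in its `call`, projects hidden states back onto the vocabulary by multiplying with the embedding matrix transposed, plus an output-only per-token bias. A toy version of that projection, using made-up sizes and random stand-in tensors:

```python
import tensorflow as tf

vocab_size, hidden_size, seq_length = 10, 4, 3

# Stand-ins for the shared word-embedding matrix and the output-only bias.
embedding_matrix = tf.random.normal((vocab_size, hidden_size))
bias = tf.zeros((vocab_size,))

hidden_states = tf.random.normal((1, seq_length, hidden_size))

# Project hidden states back onto the vocabulary by reusing the embedding weights.
flat = tf.reshape(hidden_states, (-1, hidden_size))               # (batch*seq, hidden)
logits = tf.matmul(flat, embedding_matrix, transpose_b=True)      # (batch*seq, vocab)
logits = tf.reshape(logits, (-1, seq_length, vocab_size)) + bias  # (batch, seq, vocab)
```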
+ self.input_embeddings = input_embeddings + + def build(self, input_shape=None): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + if self.built: + return + self.built = True + if getattr(self, "transform", None) is not None: + with tf.name_scope(self.transform.name): + self.transform.build(None) + + def get_output_embeddings(self) -> keras.layers.Layer: + return self.input_embeddings + + def set_output_embeddings(self, value: tf.Variable): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self) -> dict[str, tf.Variable]: + return {"bias": self.bias} + + def set_bias(self, value: tf.Variable): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.transform(hidden_states=hidden_states) + seq_length = shape_list(hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +class TFBertMLMHead(keras.layers.Layer): + def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs): + super().__init__(**kwargs) + + self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions") + + def call(self, sequence_output: tf.Tensor) -> tf.Tensor: + prediction_scores = self.predictions(hidden_states=sequence_output) + + return prediction_scores + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "predictions", None) is not None: + with tf.name_scope(self.predictions.name): + self.predictions.build(None) + + +class TFBertNSPHead(keras.layers.Layer): + def __init__(self, config: BertConfig, **kwargs): + super().__init__(**kwargs) + + self.seq_relationship = keras.layers.Dense( + units=2, + kernel_initializer=get_initializer(config.initializer_range), + name="seq_relationship", + ) + self.config = config + + def call(self, pooled_output: tf.Tensor) -> tf.Tensor: + seq_relationship_score = self.seq_relationship(inputs=pooled_output) + + return seq_relationship_score + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "seq_relationship", None) is not None: + with tf.name_scope(self.seq_relationship.name): + self.seq_relationship.build([None, None, self.config.hidden_size]) + + +@keras_serializable +class TFBertMainLayer(keras.layers.Layer): + config_class = BertConfig + + def __init__(self, config: BertConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.embeddings = TFBertEmbeddings(config, name="embeddings") + self.encoder = TFBertEncoder(config, name="encoder") + self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
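The comments above summarise the mask plumbing: a 2D padding mask of shape `[batch_size, seq_length]` is reshaped to `[batch_size, 1, 1, seq_length]` so it broadcasts over heads and query positions, and a few lines further down it is converted into an additive bias (0 for visible positions, -10000 for padding) that is simply added to the raw attention scores. A minimal illustration of the non-decoder branch, assuming a toy two-sequence batch:

```python
import tensorflow as tf

# Toy batch of two sequences; the second one has a padding token at the end.
attention_mask = tf.constant([[1, 1, 1], [1, 1, 0]], dtype=tf.float32)

# [batch_size, 1, 1, seq_length] broadcasts over heads and query positions.
extended = tf.reshape(attention_mask, (2, 1, 1, 3))

# 1 -> 0.0 (keep), 0 -> -10000.0 (mask out), ready to be added to attention scores.
extended = (1.0 - extended) * -10000.0
```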
+ attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + + +class TFBertPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + + +@dataclass +class TFBertForPreTrainingOutput(ModelOutput): + """ + Output type of [`TFBertForPreTraining`]. + + Args: + prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + prediction_logits: tf.Tensor | None = None + seq_relationship_logits: tf.Tensor | None = None + hidden_states: tuple[tf.Tensor] | tf.Tensor | None = None + attentions: tuple[tf.Tensor] | tf.Tensor | None = None + + +BERT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`BertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` ``dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False``): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). 
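Since the docstrings above describe several accepted input formats (keyword arguments, or everything packed into the first positional argument), a short usage sketch may help; it assumes network access to download the `google-bert/bert-base-uncased` checkpoint referenced throughout this file:

```python
from transformers import AutoTokenizer, TFBertModel

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")

# Keyword arguments and the "everything in the first positional argument" form are equivalent.
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
outputs = model(inputs)

print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768)
```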
+""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class TFBertModel(TFBertPreTrainedModel): + def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert") + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
Set to `False` during training, `True` during generation + """ + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + + +@add_start_docstrings( + """ +Bert Model with two heads on top as done during the pretraining: + a `masked language modeling` head and a `next sentence prediction (classification)` head. + """, + BERT_START_DOCSTRING, +) +class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [ + r"position_ids", + r"cls.predictions.decoder.weight", + r"cls.predictions.decoder.bias", + ] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.bert = TFBertMainLayer(config, name="bert") + self.nsp = TFBertNSPHead(config, name="nsp___cls") + self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + def get_prefix_bias_name(self) -> str: + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + next_sentence_label: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFBertForPreTrainingOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. 
+ kwargs (`dict[str, any]`, *optional*, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Return: + + Examples: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFBertForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = TFBertForPreTraining.from_pretrained("google-bert/bert-base-uncased") + >>> input_ids = tokenizer("Hello, my dog is cute", add_special_tokens=True, return_tensors="tf") + >>> # Batch size 1 + + >>> outputs = model(input_ids) + >>> prediction_logits, seq_relationship_logits = outputs[:2] + ```""" + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output, pooled_output = outputs[:2] + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + seq_relationship_score = self.nsp(pooled_output=pooled_output) + total_loss = None + + if labels is not None and next_sentence_label is not None: + d_labels = {"labels": labels} + d_labels["next_sentence_label"] = next_sentence_label + total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score)) + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return TFBertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "nsp", None) is not None: + with tf.name_scope(self.nsp.name): + self.nsp.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) +class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [ + r"pooler", + r"cls.seq_relationship", + r"cls.predictions.decoder.weight", + r"nsp___cls", + ] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if config.is_decoder: + logger.warning( + "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") + self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + def get_prefix_bias_name(self) -> str: + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'paris'", + expected_loss=0.88, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMaskedLMOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [ + r"pooler", + r"cls.seq_relationship", + r"cls.predictions.decoder.weight", + r"nsp___cls", + ] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") + self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls") + + def get_lm_head(self) -> keras.layers.Layer: + return self.mlm.predictions + + def get_prefix_bias_name(self) -> str: + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + @unpack_inputs + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + **kwargs, + ) -> TFCausalLMOutputWithCrossAttentions | tuple[tf.Tensor]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. 
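        A minimal sketch of cached decoding, assuming the public `google-bert/bert-base-uncased` checkpoint
        simply reloaded with `is_decoder=True` (the generations themselves are only mechanical here, since that
        checkpoint was not trained causally):

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFBertLMHeadModel

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = TFBertLMHeadModel.from_pretrained("google-bert/bert-base-uncased", is_decoder=True)

        >>> inputs = tokenizer("Hello, my dog", return_tensors="tf")
        >>> outputs = model(**inputs, use_cache=True)  # returns `past_key_values` because `use_cache=True`

        >>> # Greedy next-token step: feed only the new token together with the cached key/value states
        >>> next_token = tf.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
        >>> outputs = model(input_ids=next_token, past_key_values=outputs.past_key_values, use_cache=True)
        ```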
+ """ + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.mlm(sequence_output=sequence_output, training=training) + loss = None + + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "mlm", None) is not None: + with tf.name_scope(self.mlm.name): + self.mlm.build(None) + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", + BERT_START_DOCSTRING, +) +class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.bert = TFBertMainLayer(config, name="bert") + self.nsp = TFBertNSPHead(config, name="nsp___cls") + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + next_sentence_label: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFNextSentencePredictorOutput | tuple[tf.Tensor]: + r""" + Return: + + Examples: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFBertForNextSentencePrediction + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = TFBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." 
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="tf") + + >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0] + >>> assert logits[0][0] < logits[0][1] # the next sentence was random + ```""" + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + seq_relationship_scores = self.nsp(pooled_output=pooled_output) + next_sentence_loss = ( + None + if next_sentence_label is None + else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores) + ) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return TFNextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "nsp", None) is not None: + with tf.name_scope(self.nsp.name): + self.nsp.build(None) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + BERT_START_DOCSTRING, +) +class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.bert = TFBertMainLayer(config, name="bert") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(rate=classifier_dropout) + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + BERT_START_DOCSTRING, +) +class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.bert = TFBertMainLayer(config, name="bert") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.classifier = keras.layers.Dense( + units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. 
(See `input_ids` above) + """ + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None + flat_attention_mask = ( + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None + ) + flat_token_type_ids = ( + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None + ) + flat_position_ids = ( + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None + ) + flat_inputs_embeds = ( + tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.bert( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(inputs=pooled_output, training=training) + logits = self.classifier(inputs=pooled_output) + reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BERT_START_DOCSTRING, +) +class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' 
represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [ + r"pooler", + r"mlm___cls", + r"nsp___cls", + r"cls.predictions", + r"cls.seq_relationship", + ] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(rate=classifier_dropout) + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT, + expected_loss=_TOKEN_CLASS_EXPECTED_LOSS, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFTokenClassifierOutput | tuple[tf.Tensor]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(inputs=sequence_output, training=training) + logits = self.classifier(inputs=sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + BERT_START_DOCSTRING, +) +class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [ + r"pooler", + r"mlm___cls", + r"nsp___cls", + r"cls.predictions", + r"cls.seq_relationship", + ] + + def __init__(self, config: BertConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + + self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") + self.qa_outputs = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="qa_outputs", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_QA, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=_QA_TARGET_START_INDEX, + qa_target_end_index=_QA_TARGET_END_INDEX, + expected_output=_QA_EXPECTED_OUTPUT, + expected_loss=_QA_EXPECTED_LOSS, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: bool | None = False, + ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]: + r""" + start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + outputs = self.bert( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.qa_outputs(inputs=sequence_output) + start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1) + start_logits = tf.squeeze(input=start_logits, axis=-1) + end_logits = tf.squeeze(input=end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert", None) is not None: + with tf.name_scope(self.bert.name): + self.bert.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFBertEmbeddings", + "TFBertForMaskedLM", + "TFBertForMultipleChoice", + "TFBertForNextSentencePrediction", + "TFBertForPreTraining", + "TFBertForQuestionAnswering", + "TFBertForSequenceClassification", + "TFBertForTokenClassification", + "TFBertLMHeadModel", + "TFBertMainLayer", + "TFBertModel", + "TFBertPreTrainedModel", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..23cda58bfe723d7674a51e6c3f478f207464f639 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert.py @@ -0,0 +1,478 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes for Bert.""" + +import collections +import os +import unicodedata +from typing import Optional + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(PreTrainedTokenizer): + r""" + Construct a BERT tokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + clean_up_tokenization_spaces=True, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. 
+ """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # 
despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) + or (cp >= 0x20000 and cp <= 0x2A6DF) + or (cp >= 0x2A700 and cp <= 0x2B73F) + or (cp >= 0x2B740 and cp <= 0x2B81F) + or (cp >= 0x2B820 and cp <= 0x2CEAF) + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) + ): + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..2cdc6129881be3f88a32ab1ae653c44db2049ece --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert_fast.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
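The greedy longest-match-first WordPiece procedure described in `WordpieceTokenizer.tokenize` above can be checked directly with a toy vocabulary; this is only a sketch, and the vocabulary here is made up:

```python
>>> from transformers.models.bert.tokenization_bert import WordpieceTokenizer

>>> toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}  # made-up vocabulary
>>> wp = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
>>> wp.tokenize("unaffable")   # longest matching prefix first, then "##"-prefixed continuations
['un', '##aff', '##able']
>>> wp.tokenize("unknowable")  # no full cover exists, so the whole word maps to the unknown token
['[UNK]']
```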
+"""Fast Tokenization classes for Bert.""" + +import json +from typing import Optional + +from tokenizers import normalizers + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_bert import BertTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class BertTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" BERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = BertTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["BertTokenizerFast"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert_tf.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fca52c4cbf39d033296429e2034d7b9dd7904e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/tokenization_bert_tf.py @@ -0,0 +1,259 @@ +import os +from typing import Optional, Union + +import tensorflow as tf +from tensorflow_text import BertTokenizer as BertTokenizerLayer +from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs + +from ...modeling_tf_utils import keras +from ...utils.import_utils import requires +from .tokenization_bert import BertTokenizer + + +@requires(backends=("tf", "tensorflow_text")) +class TFBertTokenizer(keras.layers.Layer): + """ + This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the + `from_pretrained()` method. 
It can also be initialized with the `from_tokenizer()` method, which imports settings + from an existing standard tokenizer object. + + In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run + when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options + than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes + straight from `tf.string` inputs to outputs. + + Args: + vocab_list (`list`): + List containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + cls_token_id (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + sep_token_id (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token_id (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + padding (`str`, defaults to `"longest"`): + The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch, + or `"max_length", to pad all inputs to the maximum length supported by the tokenizer. + truncation (`bool`, *optional*, defaults to `True`): + Whether to truncate the sequence to the maximum length. + max_length (`int`, *optional*, defaults to `512`): + The maximum length of the sequence, used for padding (if `padding` is "max_length") and/or truncation (if + `truncation` is `True`). + pad_to_multiple_of (`int`, *optional*, defaults to `None`): + If set, the sequence will be padded to a multiple of this value. + return_token_type_ids (`bool`, *optional*, defaults to `True`): + Whether to return token_type_ids. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention_mask. + use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`): + If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer + class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to + TFLite. 
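    A minimal end-to-end sketch, assuming the public `google-bert/bert-base-uncased` checkpoint: because the
    tokenizer runs inside the graph, the resulting Keras model takes raw `tf.string` inputs directly:

    ```python
    >>> import tensorflow as tf
    >>> from transformers import TFBertTokenizer, TFBertModel

    >>> tf_tokenizer = TFBertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> bert = TFBertModel.from_pretrained("google-bert/bert-base-uncased")

    >>> text_in = tf.keras.Input(shape=(), dtype=tf.string)
    >>> tokenized = tf_tokenizer(text_in)                       # dict of input_ids / token_type_ids / attention_mask
    >>> pooled = bert(tokenized).pooler_output
    >>> model = tf.keras.Model(inputs=text_in, outputs=pooled)  # strings in, pooled embeddings out
    >>> embeddings = model(tf.constant(["Hello TensorFlow!"]))
    ```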
+ """ + + def __init__( + self, + vocab_list: list, + do_lower_case: bool, + cls_token_id: Optional[int] = None, + sep_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + padding: str = "longest", + truncation: bool = True, + max_length: int = 512, + pad_to_multiple_of: Optional[int] = None, + return_token_type_ids: bool = True, + return_attention_mask: bool = True, + use_fast_bert_tokenizer: bool = True, + **tokenizer_kwargs, + ): + super().__init__() + if use_fast_bert_tokenizer: + self.tf_tokenizer = FastBertTokenizer( + vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs + ) + else: + lookup_table = tf.lookup.StaticVocabularyTable( + tf.lookup.KeyValueTensorInitializer( + keys=vocab_list, + key_dtype=tf.string, + values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64), + value_dtype=tf.int64, + ), + num_oov_buckets=1, + ) + self.tf_tokenizer = BertTokenizerLayer( + lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs + ) + + self.vocab_list = vocab_list + self.do_lower_case = do_lower_case + self.cls_token_id = vocab_list.index("[CLS]") if cls_token_id is None else cls_token_id + self.sep_token_id = vocab_list.index("[SEP]") if sep_token_id is None else sep_token_id + self.pad_token_id = vocab_list.index("[PAD]") if pad_token_id is None else pad_token_id + self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1) # Allow room for special tokens + self.max_length = max_length + self.padding = padding + self.truncation = truncation + self.pad_to_multiple_of = pad_to_multiple_of + self.return_token_type_ids = return_token_type_ids + self.return_attention_mask = return_attention_mask + + @classmethod + def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs): # noqa: F821 + """ + Initialize a `TFBertTokenizer` from an existing `Tokenizer`. + + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer to use to initialize the `TFBertTokenizer`. + + Examples: + + ```python + from transformers import AutoTokenizer, TFBertTokenizer + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer) + ``` + """ + do_lower_case = kwargs.pop("do_lower_case", None) + do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case + cls_token_id = kwargs.pop("cls_token_id", None) + cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id + sep_token_id = kwargs.pop("sep_token_id", None) + sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id + pad_token_id = kwargs.pop("pad_token_id", None) + pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id + + vocab = tokenizer.get_vocab() + vocab = sorted(vocab.items(), key=lambda x: x[1]) + vocab_list = [entry[0] for entry in vocab] + return cls( + vocab_list=vocab_list, + do_lower_case=do_lower_case, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + pad_token_id=pad_token_id, + **kwargs, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs): + """ + Instantiate a `TFBertTokenizer` from a pre-trained tokenizer. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The name or path to the pre-trained tokenizer. 
+ + Examples: + + ```python + from transformers import TFBertTokenizer + + tf_tokenizer = TFBertTokenizer.from_pretrained("google-bert/bert-base-uncased") + ``` + """ + try: + tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) + except: # noqa: E722 + from .tokenization_bert_fast import BertTokenizerFast + + tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) + return cls.from_tokenizer(tokenizer, **kwargs) + + def unpaired_tokenize(self, texts): + if self.do_lower_case: + texts = case_fold_utf8(texts) + tokens = self.tf_tokenizer.tokenize(texts) + return tokens.merge_dims(1, -1) + + def call( + self, + text, + text_pair=None, + padding=None, + truncation=None, + max_length=None, + pad_to_multiple_of=None, + return_token_type_ids=None, + return_attention_mask=None, + ): + if padding is None: + padding = self.padding + if padding not in ("longest", "max_length"): + raise ValueError("Padding must be either 'longest' or 'max_length'!") + if max_length is not None and text_pair is not None: + # Because we have to instantiate a Trimmer to do it properly + raise ValueError("max_length cannot be overridden at call time when truncating paired texts!") + if max_length is None: + max_length = self.max_length + if truncation is None: + truncation = self.truncation + if pad_to_multiple_of is None: + pad_to_multiple_of = self.pad_to_multiple_of + if return_token_type_ids is None: + return_token_type_ids = self.return_token_type_ids + if return_attention_mask is None: + return_attention_mask = self.return_attention_mask + if not isinstance(text, tf.Tensor): + text = tf.convert_to_tensor(text) + if text_pair is not None and not isinstance(text_pair, tf.Tensor): + text_pair = tf.convert_to_tensor(text_pair) + if text_pair is not None: + if text.shape.rank > 1: + raise ValueError("text argument should not be multidimensional when a text pair is supplied!") + if text_pair.shape.rank > 1: + raise ValueError("text_pair should not be multidimensional!") + if text.shape.rank == 2: + text, text_pair = text[:, 0], text[:, 1] + text = self.unpaired_tokenize(text) + if text_pair is None: # Unpaired text + if truncation: + text = text[:, : max_length - 2] # Allow room for special tokens + input_ids, token_type_ids = combine_segments( + (text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id + ) + else: # Paired text + text_pair = self.unpaired_tokenize(text_pair) + if truncation: + text, text_pair = self.paired_trimmer.trim([text, text_pair]) + input_ids, token_type_ids = combine_segments( + (text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id + ) + if padding == "longest": + pad_length = input_ids.bounding_shape(axis=1) + if pad_to_multiple_of is not None: + # No ceiling division in tensorflow, so we negate floordiv instead + pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of)) + else: + pad_length = max_length + + input_ids, attention_mask = pad_model_inputs(input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id) + output = {"input_ids": input_ids} + if return_attention_mask: + output["attention_mask"] = attention_mask + if return_token_type_ids: + token_type_ids, _ = pad_model_inputs( + token_type_ids, max_seq_length=pad_length, pad_value=self.pad_token_id + ) + output["token_type_ids"] = token_type_ids + return output + + def get_config(self): + return { + "vocab_list": self.vocab_list, + "do_lower_case": 
self.do_lower_case, + "cls_token_id": self.cls_token_id, + "sep_token_id": self.sep_token_id, + "pad_token_id": self.pad_token_id, + } + + +__all__ = ["TFBertTokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bertweet/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bertweet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..432622f1595d1a0d8bb1b3c9b9774b7d1e387d3e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bertweet/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_bertweet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bertweet/tokenization_bertweet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bertweet/tokenization_bertweet.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce1a3182bf9d5b3f960b5a211544612ab3129c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bertweet/tokenization_bertweet.py @@ -0,0 +1,769 @@ +# coding=utf-8 +# Copyright (c) 2020, VinAI Research and the HuggingFace Inc. team. +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for BERTweet""" + +import html +import os +import re +from shutil import copyfile +from typing import Optional + +import regex + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.txt", + "merges_file": "bpe.codes", +} + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + + pairs = set(pairs) + return pairs + + +class BertweetTokenizer(PreTrainedTokenizer): + """ + Constructs a BERTweet tokenizer, using Byte-Pair-Encoding. 
+ + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + normalization (`bool`, *optional*, defaults to `False`): + Whether or not to apply a normalization preprocess. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + merges_file, + normalization=False, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + **kwargs, + ): + try: + from emoji import demojize + + self.demojizer = demojize + except ImportError: + logger.warning( + "emoji is not installed, thus not converting emoticons or emojis into text. 
Install emoji: pip3" + " install emoji==0.6.0" + ) + self.demojizer = None + + self.vocab_file = vocab_file + self.merges_file = merges_file + + self.encoder = {} + self.encoder[str(bos_token)] = 0 + self.encoder[str(pad_token)] = 1 + self.encoder[str(eos_token)] = 2 + self.encoder[str(unk_token)] = 3 + + self.add_from_file(vocab_file) + + self.decoder = {v: k for k, v in self.encoder.items()} + + with open(merges_file, encoding="utf-8") as merges_handle: + merges = merges_handle.read().split("\n")[:-1] + merges = [tuple(merge.split()[:-1]) for merge in merges] + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {} + + self.normalization = normalization + self.tweetPreprocessor = TweetTokenizer() + self.special_puncts = {"’": "'", "…": "..."} + + super().__init__( + normalization=normalization, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + **kwargs, + ) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERTweet sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. BERTweet does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. 
+ """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + word = tuple(list(word[:-1]) + [word[-1] + ""]) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = "@@ ".join(word) + word = word[:-4] + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + if self.normalization: # Perform Tweet normalization before performing BPE + text = self.normalizeTweet(text) + + split_tokens = [] + words = re.findall(r"\S+\n?", text) + for token in words: + split_tokens.extend(list(self.bpe(token).split(" "))) + return split_tokens + + def normalizeTweet(self, tweet): + """ + Normalize a raw Tweet + """ + for punct in self.special_puncts: + tweet = tweet.replace(punct, self.special_puncts[punct]) + + tokens = self.tweetPreprocessor.tokenize(tweet) + normTweet = " ".join([self.normalizeToken(token) for token in tokens]) + + normTweet = ( + normTweet.replace("cannot ", "can not ") + .replace("n't ", " n't ") + .replace("n 't ", " n't ") + .replace("ca n't", "can't") + .replace("ai n't", "ain't") + ) + normTweet = ( + normTweet.replace("'m ", " 'm ") + .replace("'re ", " 're ") + .replace("'s ", " 's ") + .replace("'ll ", " 'll ") + .replace("'d ", " 'd ") + .replace("'ve ", " 've ") + ) + normTweet = ( + normTweet.replace(" p . m .", " p.m.") + .replace(" p . m ", " p.m ") + .replace(" a . m .", " a.m.") + .replace(" a . 
m ", " a.m ") + ) + + return " ".join(normTweet.split()) + + def normalizeToken(self, token): + """ + Normalize tokens in a Tweet + """ + lowercased_token = token.lower() + if token.startswith("@"): + return "@USER" + elif lowercased_token.startswith("http") or lowercased_token.startswith("www"): + return "HTTPURL" + elif len(token) == 1: + if token in self.special_puncts: + return self.special_puncts[token] + if self.demojizer is not None: + return self.demojizer(token) + else: + return token + else: + return token + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace("@@ ", "").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + out_merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + if os.path.abspath(self.merges_file) != os.path.abspath(out_merge_file): + copyfile(self.merges_file, out_merge_file) + + return out_vocab_file, out_merge_file + + # def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + # filtered_tokens = ' '.join(self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)) + # tokens_generated_so_far = re.sub('(@@ )', '', string=filtered_tokens) + # tokens_generated_so_far = re.sub('(@@ ?$)', '', string=tokens_generated_so_far) + # return ''.join(tokens_generated_so_far) + + def add_from_file(self, f): + """ + Loads a pre-existing dictionary from a text file and adds its symbols to this instance. + """ + if isinstance(f, str): + try: + with open(f, "r", encoding="utf-8") as fd: + self.add_from_file(fd) + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception(f"Incorrect encoding detected in {f}, please rebuild the dataset") + return + + lines = f.readlines() + for lineTmp in lines: + line = lineTmp.strip() + idx = line.rfind(" ") + if idx == -1: + raise ValueError("Incorrect dictionary format, expected ' '") + word = line[:idx] + self.encoder[word] = len(self.encoder) + + +# Natural Language Toolkit: Twitter Tokenizer +# +# Copyright (C) 2001-2020 NLTK Project +# Author: Christopher Potts +# Ewan Klein (modifications) +# Pierpaolo Pantone <> (modifications) +# URL: http://nltk.org/ +# For license information, see LICENSE.TXT +# + + +""" +Twitter-aware tokenizer, designed to be flexible and easy to adapt to new domains and tasks. The basic logic is this: + +1. 
The tuple regex_strings defines a list of regular expression strings. + +2. The regex_strings strings are put, in order, into a compiled regular expression object called word_re. + +3. The tokenization is done by word_re.findall(s), where s is the user-supplied string, inside the tokenize() method of + the class Tokenizer. + +4. When instantiating Tokenizer objects, there is a single option: preserve_case. By default, it is set to True. If it + is set to False, then the tokenizer will lowercase everything except for emoticons. + +""" + + +###################################################################### +# +# import regex # https://github.com/nltk/nltk/issues/2409 +# import html +# +###################################################################### +# The following strings are components in the regular expression +# that is used for tokenizing. It's important that phone_number +# appears first in the final regex (since it can contain whitespace). +# It also could matter that tags comes after emoticons, due to the +# possibility of having text like +# +# <:| and some text >:) +# +# Most importantly, the final element should always be last, since it +# does a last ditch whitespace-based tokenization of whatever is left. + +# ToDo: Update with http://en.wikipedia.org/wiki/List_of_emoticons ? + +# This particular element is used in a couple ways, so we define it +# with a name: +# docstyle-ignore +EMOTICONS = r""" + (?: + [<>]? + [:;=8] # eyes + [\-o\*\']? # optional nose + [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth + | + [\)\]\(\[dDpP/\:\}\{@\|\\] # mouth + [\-o\*\']? # optional nose + [:;=8] # eyes + [<>]? + | + <3 # heart + )""" + +# URL pattern due to John Gruber, modified by Tom Winzig. See +# https://gist.github.com/winzig/8894715 +# docstyle-ignore +URLS = r""" # Capture 1: entire matched URL + (?: + https?: # URL protocol and colon + (?: + /{1,3} # 1-3 slashes + | # or + [a-z0-9%] # Single letter or digit or '%' + # (Trying not to match e.g. "URI::Escape") + ) + | # or + # looks like domain name followed by a slash: + [a-z0-9.\-]+[.] + (?:[a-z]{2,13}) + / + ) + (?: # One or more: + [^\s()<>{}\[\]]+ # Run of non-space, non-()<>{}[] + | # or + \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...) + | + \([^\s]+?\) # balanced parens, non-recursive: (...) + )+ + (?: # End with: + \([^\s()]*?\([^\s()]+\)[^\s()]*?\) # balanced parens, one level deep: (...(...)...) + | + \([^\s]+?\) # balanced parens, non-recursive: (...) + | # or + [^\s`!()\[\]{};:'".,<>?«»“”‘’] # not a space or one of these punct chars + ) + | # OR, the following to match naked domains: + (?: + (?\s]+>""", + # ASCII Arrows + r"""[\-]+>|<[\-]+""", + # Twitter username: + r"""(?:@[\w_]+)""", + # Twitter hashtags: + r"""(?:\#+[\w_]+[\w\'_\-]*[\w_]+)""", + # email addresses + r"""[\w.+-]+@[\w-]+\.(?:[\w-]\.?)+[\w-]""", + # docstyle-ignore + # Remaining word types: + r""" + (?:[^\W\d_](?:[^\W\d_]|['\-_])+[^\W\d_]) # Words with apostrophes or dashes. + | + (?:[+\-]?\d+[,/.:-]\d+[+\-]?) # Numbers, including fractions, decimals. + | + (?:[\w_]+) # Words without apostrophes or dashes. + | + (?:\.(?:\s*\.){1,}) # Ellipsis dots. + | + (?:\S) # Everything else that isn't whitespace. 
+ """, +) + +###################################################################### +# This is the core tokenizing regex: + +WORD_RE = regex.compile(r"""(%s)""" % "|".join(REGEXPS), regex.VERBOSE | regex.I | regex.UNICODE) + +# WORD_RE performs poorly on these patterns: +HANG_RE = regex.compile(r"([^a-zA-Z0-9])\1{3,}") + +# The emoticon string gets its own regex so that we can preserve case for +# them as needed: +EMOTICON_RE = regex.compile(EMOTICONS, regex.VERBOSE | regex.I | regex.UNICODE) + +# These are for regularizing HTML entities to Unicode: +ENT_RE = regex.compile(r"&(#?(x?))([^&;\s]+);") + + +###################################################################### +# Functions for converting html entities +###################################################################### + + +def _str_to_unicode(text, encoding=None, errors="strict"): + if encoding is None: + encoding = "utf-8" + if isinstance(text, bytes): + return text.decode(encoding, errors) + return text + + +def _replace_html_entities(text, keep=(), remove_illegal=True, encoding="utf-8"): + """ + Remove entities from text by converting them to their corresponding unicode character. + + Args: + text: + A unicode string or a byte string encoded in the given *encoding* (which defaults to 'utf-8'). + keep (list): + List of entity names which should not be replaced. This supports both numeric entities (`&#nnnn;` and + `&#hhhh;`) and named entities (such as ` ` or `>`). + remove_illegal (bool): + If `True`, entities that can't be converted are removed. Otherwise, entities that can't be converted are + kept "as is". + + Returns: A unicode string with the entities removed. + + See https://github.com/scrapy/w3lib/blob/master/w3lib/html.py + + Examples: + + ```python + >>> from nltk.tokenize.casual import _replace_html_entities + + >>> _replace_html_entities(b"Price: £100") + 'Price: \\xa3100' + + >>> print(_replace_html_entities(b"Price: £100")) + Price: £100 + ```""" + + def _convert_entity(match): + entity_body = match.group(3) + if match.group(1): + try: + if match.group(2): + number = int(entity_body, 16) + else: + number = int(entity_body, 10) + # Numeric character references in the 80-9F range are typically + # interpreted by browsers as representing the characters mapped + # to bytes 80-9F in the Windows-1252 encoding. For more info + # see: https://en.wikipedia.org/wiki/ISO/IEC_8859-1#Similar_character_sets + if 0x80 <= number <= 0x9F: + return bytes((number,)).decode("cp1252") + except ValueError: + number = None + else: + if entity_body in keep: + return match.group(0) + else: + number = html.entities.name2codepoint.get(entity_body) + if number is not None: + try: + return chr(number) + except (ValueError, OverflowError): + pass + + return "" if remove_illegal else match.group(0) + + return ENT_RE.sub(_convert_entity, _str_to_unicode(text, encoding)) + + +###################################################################### + + +class TweetTokenizer: + r""" + Examples: + + ```python + >>> # Tokenizer for tweets. + >>> from nltk.tokenize import TweetTokenizer + + >>> tknzr = TweetTokenizer() + >>> s0 = "This is a cooool #dummysmiley: :-) :-P <3 and some arrows < > -> <--" + >>> tknzr.tokenize(s0) + ['This', 'is', 'a', 'cooool', '#dummysmiley', ':', ':-)', ':-P', '<3', 'and', 'some', 'arrows', '<', '>', '->', '<--'] + + >>> # Examples using *strip_handles* and *reduce_len parameters*: + >>> tknzr = TweetTokenizer(strip_handles=True, reduce_len=True) + >>> s1 = "@remy: This is waaaaayyyy too much for you!!!!!!" 
+ >>> tknzr.tokenize(s1) + [':', 'This', 'is', 'waaayyy', 'too', 'much', 'for', 'you', '!', '!', '!'] + ```""" + + def __init__(self, preserve_case=True, reduce_len=False, strip_handles=False): + self.preserve_case = preserve_case + self.reduce_len = reduce_len + self.strip_handles = strip_handles + + def tokenize(self, text): + """ + Args: + text: str + + Returns: list(str) A tokenized list of strings; concatenating this list returns the original string if + `preserve_case=False` + """ + # Fix HTML character entities: + text = _replace_html_entities(text) + # Remove username handles + if self.strip_handles: + text = remove_handles(text) + # Normalize word lengthening + if self.reduce_len: + text = reduce_lengthening(text) + # Shorten problematic sequences of characters + safe_text = HANG_RE.sub(r"\1\1\1", text) + # Tokenize: + words = WORD_RE.findall(safe_text) + # Possibly alter the case, but avoid changing emoticons like :D into :d: + if not self.preserve_case: + words = [x if EMOTICON_RE.search(x) else x.lower() for x in words] + return words + + +###################################################################### +# Normalization Functions +###################################################################### + + +def reduce_lengthening(text): + """ + Replace repeated character sequences of length 3 or greater with sequences of length 3. + """ + pattern = regex.compile(r"(.)\1{2,}") + return pattern.sub(r"\1\1\1", text) + + +def remove_handles(text): + """ + Remove Twitter username handles from text. + """ + pattern = regex.compile( + r"(?>> from transformers import BioGptModel, BioGptConfig + + >>> # Initializing a BioGPT microsoft/biogpt style configuration + >>> configuration = BioGptConfig() + + >>> # Initializing a model from the microsoft/biogpt style configuration + >>> model = BioGptModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "biogpt" + + def __init__( + self, + vocab_size=42384, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + initializer_range=0.02, + layer_norm_eps=1e-12, + scale_embedding=True, + use_cache=True, + layerdrop=0.0, + activation_dropout=0.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.scale_embedding = scale_embedding + self.use_cache = use_cache + self.layerdrop = layerdrop + self.activation_dropout = activation_dropout + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +__all__ = ["BioGptConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/modeling_biogpt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/modeling_biogpt.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9937420025f86dc74c73b12060d28d7beec627 --- /dev/null +++ 
b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/modeling_biogpt.py @@ -0,0 +1,967 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_biogpt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg +from .configuration_biogpt import BioGptConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +class BioGptLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + + if position_ids is None: + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_key_values_length` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + return super().forward(position_ids + self.offset) + + +class BioGptScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class BioGptAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[BioGptConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class BioGptDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.hidden_size + + self.self_attn = BioGptAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_probs_dropout_prob, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.hidden_dropout_prob + self.activation_fn = ACT2FN[config.hidden_act] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
+ """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + position_ids=position_ids, + cache_position=cache_position, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class BioGptPreTrainedModel(PreTrainedModel): + config: BioGptConfig + base_model_prefix = "biogpt" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring +class BioGptModel(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.config = config + self.layerdrop = config.layerdrop + self.dropout = config.hidden_dropout_prob + self.embed_dim = config.hidden_size + self.padding_idx = config.pad_token_id + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = BioGptScaledWordEmbedding( + config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) + + self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(self.embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and 
self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize past_key_values + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None: + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = past_key_values + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + + # embed positions + if position_ids is None: + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.layer_norm(hidden_states) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + BioGPT Model with a `language modeling` head on top for CLM fine-tuning. + """ +) +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["output_projection.weight"] + + def __init__(self, config): + super().__init__(config) + + self.biogpt = BioGptModel(config) + self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.output_projection + + def set_output_embeddings(self, new_embeddings): + self.output_projection = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.biogpt( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.output_projection(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class BioGptForTokenClassification(BioGptPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.biogpt = BioGptModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + else: + classifier_dropout = config.hidden_dropout_prob + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BioGpt Model transformer with a sequence classification head on top (linear layer). + + [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it is required to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class BioGptForSequenceClassification(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.biogpt = BioGptModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + ) -> Union[tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = transformer_outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.score(hidden_states[:, slice_indices, :]) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_length = -1 + else: + if input_ids is not None: + sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_length = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.biogpt.embed_tokens + + def set_input_embeddings(self, value): + self.biogpt.embed_tokens = value + + +__all__ = [ + "BioGptForCausalLM", + "BioGptForTokenClassification", + "BioGptForSequenceClassification", + "BioGptModel", + "BioGptPreTrainedModel", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/modular_biogpt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/modular_biogpt.py new file mode 100644 index 0000000000000000000000000000000000000000..8d95b2a2d051a960715fa10edc9783f380a53696 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/modular_biogpt.py @@ -0,0 +1,789 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team and 
Microsoft Research AI4Science All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BioGPT model.""" + +import math +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + TransformersKwargs, + auto_docstring, + is_torch_flex_attn_available, + logger, +) +from ...utils.deprecation import deprecate_kwarg +from ..bart.modeling_bart import ( + BartAttention, + BartDecoderLayer, + BartScaledWordEmbedding, +) +from ..opt.modeling_opt import OPTLearnedPositionalEmbedding +from .configuration_biogpt import BioGptConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding): + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + return super().forward(attention_mask, past_key_values_length, position_ids) + + +class BioGptScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +class BioGptAttention(BartAttention): + pass + + +class BioGptDecoderLayer(BartDecoderLayer): + def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): + super().__init__(config) + self.embed_dim = config.hidden_size + + self.self_attn = BioGptAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_probs_dropout_prob, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.hidden_dropout_prob + self.activation_fn = ACT2FN[config.hidden_act] + + self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim) + + del self.encoder_attn + del self.encoder_attn_layer_norm + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: +
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + position_ids=position_ids, + cache_position=cache_position, + **kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class BioGptPreTrainedModel(PreTrainedModel): + config: BioGptConfig + base_model_prefix = "biogpt" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=input_tensor.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its
`attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring +class BioGptModel(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.config = config + self.layerdrop = config.layerdrop + self.dropout = config.hidden_dropout_prob + self.embed_dim = config.hidden_size + self.padding_idx = config.pad_token_id + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = BioGptScaledWordEmbedding( + config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) + + self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(self.embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and 
self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize past_key_values + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None: + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = past_key_values + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + + # embed positions + if position_ids is None: + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.layer_norm(hidden_states) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring( + custom_intro=""" + BioGPT Model with a `language modeling` head on top for CLM fine-tuning. + """ +) +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["output_projection.weight"] + + def __init__(self, config): + super().__init__(config) + + self.biogpt = BioGptModel(config) + self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.output_projection + + def set_output_embeddings(self, new_embeddings): + self.output_projection = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.biogpt( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.output_projection(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@auto_docstring +class BioGptForTokenClassification(BioGptPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.biogpt = BioGptModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + else: + classifier_dropout = config.hidden_dropout_prob + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BioGpt Model transformer with a sequence classification head on top (linear layer). + + [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it is required to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class BioGptForSequenceClassification(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.biogpt = BioGptModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + ) -> Union[tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = transformer_outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.score(hidden_states[:, slice_indices, :]) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_length = -1 + else: + if input_ids is not None: + sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_length = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.biogpt.embed_tokens + + def set_input_embeddings(self, value): + self.biogpt.embed_tokens = value + + +__all__ = [ + "BioGptForCausalLM", + "BioGptForTokenClassification", + "BioGptForSequenceClassification", + "BioGptModel", + "BioGptPreTrainedModel", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/tokenization_biogpt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/tokenization_biogpt.py new file mode 100644 index 0000000000000000000000000000000000000000..f84403ca7ddc65586528da07e4efc9f8621a0e19 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/biogpt/tokenization_biogpt.py @@ -0,0 +1,331 @@ +# coding=utf-8 +# Copyright 2022 The 
HuggingFace Team and Microsoft Research AI4Science. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for BioGPT.""" + +import json +import os +from typing import Optional + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length + strings) + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class BioGptTokenizer(PreTrainedTokenizer): + """ + Construct a FAIRSEQ Transformer tokenizer. Moses tokenization followed by Byte-Pair Encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Merges file. + unk_token (`str`, *optional*, defaults to `"<unk>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*, defaults to `"<s>"`): + The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token. + + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + </Tip> + + eos_token (`str`, *optional*, defaults to `"</s>"`): + The end of sequence token. + + <Tip> + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + </Tip> + + sep_token (`str`, *optional*, defaults to `"</s>"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + unk_token="<unk>", + bos_token="<s>", + eos_token="</s>", + sep_token="</s>", + pad_token="<pad>", + **kwargs, + ): + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use BioGptTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation."
+ ) + + self.lang = "en" + self.sm = sacremoses + # cache of sm.MosesTokenizer instance + self.cache_moses_tokenizer = {} + self.cache_moses_detokenizer = {} + + """ Initialisation""" + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + merges = merges_handle.read().split("\n")[:-1] + merges = [tuple(merge.split()[:2]) for merge in merges] + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {} + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + unk_token=unk_token, + pad_token=pad_token, + **kwargs, + ) + + @property + def vocab_size(self): + """Returns vocab size""" + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def moses_tokenize(self, text, lang): + if lang not in self.cache_moses_tokenizer: + moses_tokenizer = self.sm.MosesTokenizer(lang=lang) + self.cache_moses_tokenizer[lang] = moses_tokenizer + return self.cache_moses_tokenizer[lang].tokenize( + text, aggressive_dash_splits=True, return_str=False, escape=True + ) + + def moses_detokenize(self, tokens, lang): + if lang not in self.cache_moses_detokenizer: + moses_detokenizer = self.sm.MosesDetokenizer(lang=lang) + self.cache_moses_detokenizer[lang] = moses_detokenizer + return self.cache_moses_detokenizer[lang].detokenize(tokens) + + def bpe(self, token): + word = tuple(token[:-1]) + (token[-1] + "</w>",) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + "</w>" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n  </w>": + word = "\n</w>" + self.cache[token] = word + return word + + def _tokenize(self, text, bypass_tokenizer=False): + """Returns a tokenized string.""" + if bypass_tokenizer: + text = text.split() + else: + text = self.moses_tokenize(text, self.lang) + + split_tokens = [] + for token in text: + if token: + split_tokens.extend(list(self.bpe(token).split(" "))) + + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) to an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) to a token (str) using the vocab.""" + return self.decoder.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) into a single string.""" + # remove BPE + tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens] + tokens = "".join(tokens).split() + # detokenize + text = self.moses_detokenize(tokens, self.lang) + return text + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of
sequence for sequence classification tasks by concatenating and + adding special tokens. A BioGPT sequence has the following format: + + - single sequence: `</s> X` + - pair of sequences: `</s> A </s> B` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.sep_token_id] + token_ids_0 + sep = [self.sep_token_id] + return sep + token_ids_0 + sep + token_ids_1 + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + # no bos used in fairseq + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + return [1] + ([0] * len(token_ids_0)) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def __getstate__(self): + state = self.__dict__.copy() + state["sm"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use BioGptTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation."
+ ) + + self.sm = sacremoses + + +__all__ = ["BioGptTokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76ece6853b381b0853f52022078f9c580deb0f04 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_blenderbot import * + from .modeling_blenderbot import * + from .modeling_flax_blenderbot import * + from .modeling_tf_blenderbot import * + from .tokenization_blenderbot import * + from .tokenization_blenderbot_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/configuration_blenderbot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/configuration_blenderbot.py new file mode 100644 index 0000000000000000000000000000000000000000..44287991375a2f83a76898ab10b1fbd14e31ffd2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/configuration_blenderbot.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Blenderbot model configuration""" + +from collections import OrderedDict +from collections.abc import Mapping +from typing import Any, Optional + +from ... import PreTrainedTokenizer +from ...configuration_utils import PretrainedConfig +from ...file_utils import TensorType, is_torch_available +from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx.utils import compute_effective_axis_dimension +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BlenderbotConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BlenderbotModel`]. 
It is used to instantiate a + Blenderbot model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Blenderbot + [facebook/blenderbot-3B](https://huggingface.co/facebook/blenderbot-3B) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 8008): + Vocabulary size of the Blenderbot model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`BlenderbotModel`] or [`TFBlenderbotModel`]. + d_model (`int`, *optional*, defaults to 2560): + Dimensionality of the layers and the pooler layer. + encoder_layers (`int`, *optional*, defaults to 2): + Number of encoder layers. + decoder_layers (`int`, *optional*, defaults to 24): + Number of decoder layers. + encoder_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 10240): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (`int`, *optional*, defaults to 10240): + Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + max_position_embeddings (`int`, *optional*, defaults to 128): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556) + for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556) + for more details. + scale_embedding (`bool`, *optional*, defaults to `False`): + Scale embeddings by dividing by sqrt(d_model). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + forced_eos_token_id (`int`, *optional*, defaults to 2): + The id of the token to force as the last generated token when `max_length` is reached. Usually set to + `eos_token_id`.
+ + Example: + + ```python + >>> from transformers import BlenderbotConfig, BlenderbotModel + + >>> # Initializing a Blenderbot facebook/blenderbot-3B style configuration + >>> configuration = BlenderbotConfig() + + >>> # Initializing a model (with random weights) from the facebook/blenderbot-3B style configuration + >>> model = BlenderbotModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "blenderbot" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=8008, + max_position_embeddings=128, + encoder_layers=2, + encoder_ffn_dim=10240, + encoder_attention_heads=32, + decoder_layers=24, + decoder_ffn_dim=10240, + decoder_attention_heads=32, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu", + d_model=2560, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=1, + scale_embedding=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + encoder_no_repeat_ngram_size=3, + forced_eos_token_id=2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + forced_eos_token_id=forced_eos_token_id, + **kwargs, + ) + + +class BlenderbotOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + _, num_decoder_layers = self.num_layers + for i in range(num_decoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + 
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs + + @property + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + common_inputs["past_key_values"] = [] + _, num_decoder_layers = self.num_layers + + for _ in range(num_decoder_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, 
seqlen = common_inputs["input_ids"].shape + past_key_values_length = seqlen + _, num_decoder_layers = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + mask_dtype = common_inputs["attention_mask"].dtype + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_decoder_layers) + ] + return common_inputs + + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering + def _generate_dummy_inputs_for_sequence_classification_and_question_answering( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.generate_dummy_inputs + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + elif self.task == "causal-lm": + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + else: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_ + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t + ) + + def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): + if 
direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + name = "past_key_values" if direction == "inputs" else "present" + _, num_decoder_layers = self.num_layers + + encoder_sequence = "past_encoder_sequence" + decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence" + + for i in range(num_decoder_layers): + inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence} + inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence} + inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence} + inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence} + + +__all__ = ["BlenderbotConfig", "BlenderbotOnnxConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_blenderbot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_blenderbot.py new file mode 100644 index 0000000000000000000000000000000000000000..3e25fad20d31aeb6c1bae19985e638f803a36a7b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_blenderbot.py @@ -0,0 +1,1595 @@ +# coding=utf-8 +# Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
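As a side note on the `BlenderbotOnnxConfig` defined above: its `inputs` property, together with `fill_with_past_key_values_`, is what declares the dynamic axes an ONNX export would use. A rough sketch of inspecting those axes, assuming the usual `transformers.onnx` constructor arguments (`task`, `use_past`):

```python
from transformers import BlenderbotConfig
from transformers.models.blenderbot.configuration_blenderbot import BlenderbotOnnxConfig

config = BlenderbotConfig()
# with use_past=True, the cached key/value tensors become additional ONNX inputs
onnx_config = BlenderbotOnnxConfig(config, task="default", use_past=True)

for name, axes in onnx_config.inputs.items():
    print(name, axes)
# expected entries include, for example:
#   decoder_attention_mask        {0: 'batch', 1: 'past_decoder_sequence + sequence'}
#   past_key_values.0.decoder.key {0: 'batch', 2: 'past_decoder_sequence'}
```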
+"""PyTorch Blenderbot model.""" + +import math +import os +import warnings +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel +from .configuration_blenderbot import BlenderbotConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.bart.modeling_bart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class BlenderbotLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings, embedding_dim) + + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) + + +# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot +class BlenderbotScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale + + +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot +class BlenderbotAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[BlenderbotConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + is_updated = False + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.layers[self.layer_idx].keys + value_states = curr_past_key_value.layers[self.layer_idx].values + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache): + past_key_values.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + 
**kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT +class BlenderbotEncoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BlenderbotConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BlenderbotAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool = False, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states, attn_weights + + +# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT +class BlenderbotDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BlenderbotAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = 
nn.LayerNorm(self.embed_dim) + self.encoder_attn = BlenderbotAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_values (`Cache`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
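A rough, self-contained sketch of running one decoder layer on its own (tiny, made-up sizes; with no mask, cache, or encoder states only the self-attention and feed-forward blocks run):

```python
>>> import torch
>>> from transformers import BlenderbotConfig
>>> from transformers.models.blenderbot.modeling_blenderbot import BlenderbotDecoderLayer

>>> cfg = BlenderbotConfig(d_model=16, decoder_attention_heads=4, decoder_ffn_dim=32)
>>> layer = BlenderbotDecoderLayer(cfg, layer_idx=0)
>>> hidden = torch.randn(2, 5, cfg.d_model)
>>> layer(hidden)[0].shape  # the first element of the output tuple is the new hidden states
torch.Size([2, 5, 16])
```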
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + past_key_values=past_key_values, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@auto_docstring +class BlenderbotPreTrainedModel(PreTrainedModel): + config: BlenderbotConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + "decoder_input_ids": input_ids, + } + return dummy_inputs + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class BlenderbotEncoder(BlenderbotPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`BlenderbotEncoderLayer`]. + + Args: + config: BlenderbotConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = BlenderbotScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = BlenderbotLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # add final layer norm + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BlenderbotDecoder(BlenderbotPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`BlenderbotDecoderLayer`] + + Args: + config: BlenderbotConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = BlenderbotScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) + + self.embed_positions = BlenderbotLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList( + [BlenderbotDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position: Optional[torch.Tensor] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + if use_cache and past_key_values is None: + past_key_values = ( + EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) + if encoder_hidden_states is not None + else DynamicCache(config=self.config) + ) + if use_cache and isinstance(past_key_values, tuple): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. 
" + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) + + # embed positions + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) + + hidden_states = inputs_embeds + position_ids + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add final layer norm + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@auto_docstring +class BlenderbotModel(BlenderbotPreTrainedModel): + _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + + def __init__(self, config: BlenderbotConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + self.encoder = BlenderbotEncoder(config, self.shared) + self.decoder = BlenderbotDecoder(config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + if pretrained_model_name_or_path == "facebook/blenderbot-90M": + warnings.warn( + "The checkpoint `facebook/blenderbot-90M` is deprecated. 
In the future, please use the identical" + " checkpoint `facebook/small_blenderbot-90M` with" + " `BlenderbotSmallModel.from_pretrained('facebook/small_blenderbot-90M')` instead.", + FutureWarning, + ) + return BlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, BlenderbotModel + + >>> model = BlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + + >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt") + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 6, 1280] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The Blenderbot Model with a language modeling head. Can be used for summarization. 
+ """ +) +class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: BlenderbotConfig): + super().__init__(config) + self.model = BlenderbotModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + if pretrained_model_name_or_path == "facebook/blenderbot-90M": + warnings.warn( + "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical" + " checkpoint `facebook/small_blenderbot-90M` with" + " `BlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.", + FutureWarning, + ) + return BlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are decoder input IDs?](../glossary#decoder-input-ids) + + Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example conversation: + + ```python + >>> from transformers import AutoTokenizer, BlenderbotForConditionalGeneration + + >>> mname = "facebook/blenderbot-400M-distill" + >>> model = BlenderbotForConditionalGeneration.from_pretrained(mname) + >>> tokenizer = AutoTokenizer.from_pretrained(mname) + >>> UTTERANCE = "My friends are cool but they eat too many carbs." + >>> print("Human: ", UTTERANCE) + Human: My friends are cool but they eat too many carbs. + + >>> inputs = tokenizer([UTTERANCE], return_tensors="pt") + >>> reply_ids = model.generate(**inputs) + >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]) + Bot: That's unfortunate. Are they trying to lose weight or are they just trying to be healthier? + + >>> REPLY = "I'm not sure" + >>> print("Human: ", REPLY) + Human: I'm not sure + + >>> NEXT_UTTERANCE = ( + ... "My friends are cool but they eat too many carbs. That's unfortunate. " + ... "Are they trying to lose weight or are they just trying to be healthier? " + ... " I'm not sure." + ... ) + >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="pt") + >>> next_reply_ids = model.generate(**inputs) + >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0]) + Bot: I see. Well, it's good that they're trying to change their eating habits. 
+ ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot +class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BlenderbotDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill +class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = BlenderbotDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[tuple, CausalLMOutputWithCrossAttentions]: + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BlenderbotForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + >>> model = BlenderbotForCausalLM.from_pretrained("facebook/blenderbot-400M-distill", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +__all__ = [ + "BlenderbotForCausalLM", + "BlenderbotForConditionalGeneration", + "BlenderbotModel", + "BlenderbotPreTrainedModel", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_flax_blenderbot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_flax_blenderbot.py new file mode 100644 index 0000000000000000000000000000000000000000..8b147211881b778d60236615ecc240812b2865b1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -0,0 +1,1508 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Flax Blenderbot model.""" + +import math +import random +from functools import partial +from typing import Callable, Optional + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_blenderbot import BlenderbotConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BlenderbotConfig" +_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill" + + +BLENDERBOT_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`BlenderbotConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BLENDERBOT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +BLENDERBOT_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +BLENDERBOT_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right. 
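+ + For illustration, with a hypothetical `pad_token_id=0` and `decoder_start_token_id=1`, `input_ids` of `[[5, 6, 7]]` are shifted to `[[1, 5, 6]]`; any `-100` values in the shifted ids are then replaced by the pad token.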
+ """ + shifted_input_ids = jnp.zeros_like(input_ids) + shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1]) + shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id) + + shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Blenderbot +class FlaxBlenderbotAttention(nn.Module): + config: BlenderbotConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
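+ # The bias is 0.0 where attention is allowed and the most negative finite value of the dtype where it is masked; `dot_product_attention_weights` adds it to the attention logits before the softmax.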
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Blenderbot +class FlaxBlenderbotEncoderLayer(nn.Module): + config: BlenderbotConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxBlenderbotAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->Blenderbot +class FlaxBlenderbotEncoderLayerCollection(nn.Module): + config: BlenderbotConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxBlenderbotEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: 
bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Blenderbot +class FlaxBlenderbotDecoderLayer(nn.Module): + config: BlenderbotConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxBlenderbotAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.encoder_attn = FlaxBlenderbotAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + 
key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Blenderbot +class FlaxBlenderbotDecoderLayerCollection(nn.Module): + config: BlenderbotConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxBlenderbotDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxBlenderbotEncoder(nn.Module): + config: BlenderbotConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if 
self.config.scale_embedding else 1.0 + + self.embed_positions = nn.Embed( + self.config.max_position_embeddings, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.layers = FlaxBlenderbotEncoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) + + if not return_dict: + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxBlenderbotDecoder(nn.Module): + config: BlenderbotConfig + embed_tokens: nn.Embed + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + self.embed_positions = nn.Embed( + self.config.max_position_embeddings, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.layers = FlaxBlenderbotDecoderLayerCollection(self.config, self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + positions + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = 
self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) + + if not return_dict: + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->Blenderbot +class FlaxBlenderbotModule(nn.Module): + config: BlenderbotConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.encoder = FlaxBlenderbotEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxBlenderbotDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel): + config_class = BlenderbotConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: BlenderbotConfig, + input_shape: tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + # make sure initialization pass will work for 
FlaxBlenderbotForSequenceClassificationModule + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
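+ + Example (an illustrative sketch; assumes `model` and `encoder_outputs` were created as in the [`~FlaxBlenderbotPreTrainedModel.encode`] example, and the `max_length` of 64 is arbitrary): + + ```python + >>> import jax.numpy as jnp + + >>> past_key_values = model.init_cache(batch_size=1, max_length=64, encoder_outputs=encoder_outputs) + >>> decoder_input_ids = jnp.ones((1, 1), dtype="i4") * model.config.decoder_start_token_id + >>> outputs = model.decode( + ... decoder_input_ids, + ... encoder_outputs, + ... past_key_values=past_key_values, + ... decoder_position_ids=jnp.array([[0]]), + ... ) + ```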
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(BLENDERBOT_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BlenderbotConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration + + >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") + >>> encoder_outputs = model.encode(**inputs) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(BLENDERBOT_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BlenderbotConfig + ) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration + + >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBlenderbotAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: Optional[dict] = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # prepare encoder inputs
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+        if position_ids is None:
+            batch_size, sequence_length = input_ids.shape
+            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        # prepare decoder inputs
+        if decoder_input_ids is None:
+            decoder_input_ids = shift_tokens_right(
+                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
+            )
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        if decoder_position_ids is None:
+            batch_size, sequence_length = decoder_input_ids.shape
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+        return self.module.apply(
+            {"params": params or self.params},
+            input_ids=jnp.array(input_ids, dtype="i4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            position_ids=jnp.array(position_ids, dtype="i4"),
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+        )
+
+
+@add_start_docstrings(
+    "The bare Blenderbot Model transformer outputting raw hidden-states without any specific head on top.",
+    BLENDERBOT_START_DOCSTRING,
+)
+class FlaxBlenderbotModel(FlaxBlenderbotPreTrainedModel):
+    config: BlenderbotConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    module_class = FlaxBlenderbotModule
+
+
+append_call_sample_docstring(FlaxBlenderbotModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->Blenderbot
+class FlaxBlenderbotForConditionalGenerationModule(nn.Module):
+    config: BlenderbotConfig
+    dtype: jnp.dtype = jnp.float32
+    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
+
+    def setup(self):
+        self.model = FlaxBlenderbotModule(config=self.config, dtype=self.dtype)
+        self.lm_head = nn.Dense(
+            self.model.shared.num_embeddings,
+            use_bias=False,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.init_std),
+        )
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
+
+    def _get_encoder_module(self):
+        return self.model.encoder
+
+    def _get_decoder_module(self):
+        return self.model.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        outputs = self.model(
+ input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.variables["params"]["shared"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + "The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING +) +class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel): + module_class = FlaxBlenderbotForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(BLENDERBOT_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BlenderbotConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: Optional[dict] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: Optional[dict] = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration + + >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax") + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBlenderbotAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model 
output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + +FLAX_BLENDERBOT_CONDITIONAL_GENERATION_DOCSTRING = r""" + Returns: + + Conversation example:: + + ```py + >>> from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration + + >>> model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill") + + >>> UTTERANCE = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer([UTTERANCE], max_length=1024, return_tensors="np") + + >>> # Generate Reply + >>> reply_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5, early_stopping=True).sequences + >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in reply_ids]) + ``` +""" + +overwrite_call_docstring( + FlaxBlenderbotForConditionalGeneration, + BLENDERBOT_INPUTS_DOCSTRING + FLAX_BLENDERBOT_CONDITIONAL_GENERATION_DOCSTRING, +) +append_replace_return_docstrings( + FlaxBlenderbotForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) + + +__all__ = ["FlaxBlenderbotForConditionalGeneration", "FlaxBlenderbotModel", "FlaxBlenderbotPreTrainedModel"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_tf_blenderbot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_tf_blenderbot.py new file mode 100644 index 0000000000000000000000000000000000000000..78f4f6a6761e71087b5d232a42cb137595af9702 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -0,0 +1,1557 @@ +# coding=utf-8 +# Copyright 2021 The Facebook, Inc and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
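The Flax `decode` method above, together with `init_cache` and the `prepare_inputs_for_generation` / `update_inputs_for_generation` hooks, is what allows autoregressive generation to reuse the decoder key/value cache. Below is a minimal, illustrative sketch of that loop (greedy decoding only); it is not part of the library source, and `model.generate()` is the supported way to run it:

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxBlenderbotForConditionalGeneration

model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")

inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="jax")
encoder_outputs = model.encode(**inputs)  # run the encoder once, reuse its outputs every step

batch_size = inputs["input_ids"].shape[0]
max_length = 16

# Pre-allocate the decoder cache; each `decode` call returns an updated copy of it.
past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)

decoder_input_ids = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
generated = decoder_input_ids

for step in range(max_length - 1):
    # `decoder_position_ids` must be given explicitly when `past_key_values` is passed.
    position_ids = jnp.full((batch_size, 1), step, dtype="i4")
    outputs = model.decode(
        decoder_input_ids,
        encoder_outputs,
        decoder_position_ids=position_ids,
        past_key_values=past_key_values,
    )
    past_key_values = outputs.past_key_values
    # Greedy pick of the next token; only that token is fed back in, the cache holds the rest.
    decoder_input_ids = jnp.argmax(outputs.logits[:, -1:, :], axis=-1).astype("i4")
    generated = jnp.concatenate([generated, decoder_input_ids], axis=-1)

print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```

In practice `model.generate(**inputs)` runs this loop internally, with `prepare_inputs_for_generation` building the static decoder attention mask and position ids shown above so the whole loop can be compiled once.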
+"""TF 2.0 Blenderbot model.""" + +from __future__ import annotations + +import os +import random +import warnings + +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPastAndCrossAttentions, + TFSeq2SeqLMOutput, + TFSeq2SeqModelOutput, +) + +# Public API +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFPreTrainedModel, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_blenderbot import BlenderbotConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill" +_CONFIG_FOR_DOC = "BlenderbotConfig" + + +LARGE_NEGATIVE = -1e8 + + +# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + start_tokens = tf.fill( + (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype) + ) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, + tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)), + shifted_input_ids, + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +# Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz = input_ids_shape[0] + tgt_len = input_ids_shape[1] + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE + mask_cond = tf.range(shape_list(mask)[-1]) + + mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) + + if past_key_values_length > 0: + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) + + +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFBlenderbotLearnedPositionalEmbedding(keras.layers.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + + def call( + self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None + ): + """Input is expected to be of size [bsz x seqlen].""" + if position_ids is None: + seq_len = input_shape[1] + position_ids = tf.range(seq_len, delta=1, name="range") + position_ids += past_key_values_length + + return super().call(tf.cast(position_ids, dtype=tf.int32)) + + +# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Blenderbot +class TFBlenderbotAttention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: tuple[tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, 
self.embed_dim]) + + +# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartEncoderLayer with MBart->Blenderbot +class TFBlenderbotEncoderLayer(keras.layers.Layer): + def __init__(self, config: BlenderbotConfig, **kwargs): + super().__init__(**kwargs) + self.embed_dim = config.d_model + self.self_attn = TFBlenderbotAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn" + ) + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + layer_head_mask: tf.Tensor, + training: bool | None = False, + ): + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(encoder_attention_heads,)* + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, self_attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask + ) + + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return hidden_states, self_attn_weights + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.encoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +# Copied from transformers.models.mbart.modeling_tf_mbart.TFMBartDecoderLayer with MBart->Blenderbot +class TFBlenderbotDecoderLayer(keras.layers.Layer): + def __init__(self, config: BlenderbotConfig, **kwargs): + super().__init__(**kwargs) 
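+        # Pre-norm decoder block: LayerNorm is applied before each sub-layer (self-attention, cross-attention,
+        # feed-forward) and every sub-layer output is added back to its residual; see `call` below.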
+ self.embed_dim = config.d_model + self.self_attn = TFBlenderbotAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + name="self_attn", + is_decoder=True, + ) + self.dropout = keras.layers.Dropout(config.dropout) + self.activation_fn = get_tf_activation(config.activation_function) + self.activation_dropout = keras.layers.Dropout(config.activation_dropout) + + self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm") + self.encoder_attn = TFBlenderbotAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + name="encoder_attn", + is_decoder=True, + ) + self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm") + self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1") + self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + cross_attn_layer_head_mask: tf.Tensor | None = None, + past_key_value: tuple[tf.Tensor] | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor, tf.Tensor, tuple[tuple[tf.Tensor]]]: + """ + Args: + hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)* + attention_mask (`tf.Tensor`): attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + encoder_hidden_states (`tf.Tensor`): + cross attention input to the layer of shape *(batch, seq_len, embed_dim)* + encoder_attention_mask (`tf.Tensor`): encoder attention mask of size + *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. + layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size + *(decoder_attention_heads,)* + cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module. 
+ *(decoder_attention_heads,)* + past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout(hidden_states, training=training) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = residual + hidden_states + + return ( + hidden_states, + self_attn_weights, + cross_attn_weights, + present_key_value, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "self_attn_layer_norm", None) is not None: + with tf.name_scope(self.self_attn_layer_norm.name): + self.self_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "encoder_attn", None) is not None: + with tf.name_scope(self.encoder_attn.name): + self.encoder_attn.build(None) + if getattr(self, "encoder_attn_layer_norm", None) is not None: + with tf.name_scope(self.encoder_attn_layer_norm.name): + self.encoder_attn_layer_norm.build([None, None, self.embed_dim]) + if getattr(self, "fc1", None) is not None: + with tf.name_scope(self.fc1.name): + self.fc1.build([None, None, self.embed_dim]) + if getattr(self, "fc2", None) is not None: + with tf.name_scope(self.fc2.name): + self.fc2.build([None, None, self.config.decoder_ffn_dim]) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.embed_dim]) + + +class TFBlenderbotPreTrainedModel(TFPreTrainedModel): + config_class = BlenderbotConfig + base_model_prefix = "model" + + +BLENDERBOT_START_DOCSTRING = r""" + This 
model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`BlenderbotConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BLENDERBOT_GENERATION_EXAMPLE = r""" + Conversation example:: + + ```py + >>> from transformers import AutoTokenizer, TFBlenderbotForConditionalGeneration + + >>> mname = "facebook/blenderbot-400M-distill" + >>> model = TFBlenderbotForConditionalGeneration.from_pretrained(mname) + >>> tokenizer = AutoTokenizer.from_pretrained(mname) + >>> UTTERANCE = "My friends are cool but they eat too many carbs." + >>> print("Human: ", UTTERANCE) + + >>> inputs = tokenizer([UTTERANCE], return_tensors="tf") + >>> reply_ids = model.generate(**inputs) + >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]) + + >>> REPLY = "I'm not sure" + >>> print("Human: ", REPLY) + >>> NEXT_UTTERANCE = ( + ... "My friends are cool but they eat too many carbs. That's unfortunate. " + ... "Are they trying to lose weight or are they just trying to be healthier? " + ... " I'm not sure." + ... 
)
+    >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="tf")
+    >>> next_reply_ids = model.generate(**inputs)
+    >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0])
+    ```
+"""
+
+BLENDERBOT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            If not provided, a default mask that ignores pad tokens will be created. It is not recommended to set
+            this for most use cases.
+        decoder_position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+            range `[0, config.max_position_embeddings - 1]`.
+        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tf.FloatTensor`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder, of shape `(batch_size,
+            sequence_length, hidden_size)`. Used in the cross-attention of the decoder.
+        past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@keras_serializable +class TFBlenderbotEncoder(keras.layers.Layer): + config_class = BlenderbotConfig + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`TFBlenderbotEncoderLayer`]. + + Args: + config: BlenderbotConfig + """ + + def __init__(self, config: BlenderbotConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.dropout = keras.layers.Dropout(config.dropout) + self.layerdrop = config.encoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + + self.embed_tokens = embed_tokens + self.embed_positions = TFBlenderbotLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + """ + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, `optional): + Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value + in the config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. This argument can be used only in eager mode, in graph mode the value in the config + will be used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used + in eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + hidden_states = inputs_embeds + embed_pos + hidden_states = self.dropout(hidden_states, training=training) + + # check attention mask and invert + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask) + else: + attention_mask = None + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + tf.debugging.assert_equal( + shape_list(head_mask)[0], + len(self.layers), + message=( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(head_mask)[0]}." 
+ ), + ) + + # encoder layers + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if training and (dropout_probability < self.layerdrop): # skip the layer + continue + + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + ) + + if output_attentions: + all_attentions += (attn,) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFBlenderbotDecoder(keras.layers.Layer): + config_class = BlenderbotConfig + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFBlenderbotDecoderLayer`] + + Args: + config: BlenderbotConfig + embed_tokens: output embedding + """ + + def __init__(self, config: BlenderbotConfig, embed_tokens: keras.layers.Embedding | None = None, **kwargs): + super().__init__(**kwargs) + self.config = config + self.padding_idx = config.pad_token_id + self.embed_tokens = embed_tokens + self.layerdrop = config.decoder_layerdrop + self.embed_positions = TFBlenderbotLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + name="embed_positions", + ) + self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0 + self.layers = [TFBlenderbotDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)] + self.layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm") + + self.dropout = keras.layers.Dropout(config.dropout) + + def get_embed_tokens(self): + return self.embed_tokens + + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + + @unpack_inputs + def call( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + r""" + Args: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. This argument can be used only in eager mode, in graph mode the value + in the config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. This argument can be used only in eager mode, in graph mode the value in the config + will be used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used + in eager mode, in graph mode the value will always be set to True. 
+ training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 + + # embed positions + if position_ids is None: + positions = self.embed_positions(input_shape, past_key_values_length) + else: + positions = self.embed_positions(input_shape, position_ids=position_ids) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim) + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + hidden_states = inputs_embeds + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + else: + combined_attention_mask = _expand_mask( + tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] + ) + + if attention_mask is not None: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) + + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) + + hidden_states = hidden_states + positions + hidden_states = self.dropout(hidden_states, training=training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + present_key_values = () if use_cache else None + + # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None: + tf.debugging.assert_equal( + shape_list(attn_mask)[0], + len(self.layers), + message=( + f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for" + f" {shape_list(attn_mask)[0]}." 
+ ), + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + + if training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + past_key_value=past_key_value, + ) + + if use_cache: + present_key_values += (present_key_value,) + + if output_attentions: + all_self_attns += (layer_self_attn,) + + if encoder_hidden_states is not None: + all_cross_attns += (layer_cross_attn,) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns + else: + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embed_positions", None) is not None: + with tf.name_scope(self.embed_positions.name): + self.embed_positions.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.d_model]) + if getattr(self, "layers", None) is not None: + for layer in self.layers: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFBlenderbotMainLayer(keras.layers.Layer): + config_class = BlenderbotConfig + + def __init__(self, config: BlenderbotConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.shared = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.d_model, + embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std), + name="model.shared", + ) + # Additional attribute to specify the expected name scope of the layer (for loading/storing weights) + self.shared.load_weight_prefix = "model.shared" + + self.encoder = TFBlenderbotEncoder(config, self.shared, name="encoder") + self.decoder = TFBlenderbotDecoder(config, self.shared, name="decoder") + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + decoder_position_ids=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + **kwargs, + ): + output_hidden_states = 
( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return TFSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + # The shared/tied weights expect to be in the model base namespace + # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than + # the current one. + with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"): + self.shared.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +@add_start_docstrings( + "The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.", + BLENDERBOT_START_DOCSTRING, +) +class TFBlenderbotModel(TFBlenderbotPreTrainedModel): + def __init__(self, config: BlenderbotConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.model = TFBlenderbotMainLayer(config, name="model") + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None, *model_args, **kwargs): + if pretrained_model_name_or_path == "facebook/blenderbot-90M": + from ..blenderbot_small import TFBlenderbotSmallModel + + warnings.warn( + "The checkpoint `facebook/blenderbot-90M` is deprecated. 
In the future, please use the identical" + " checkpoint `facebook/small_blenderbot-90M` with" + " `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')`" + " instead.", + FutureWarning, + ) + return TFBlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSeq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + decoder_input_ids: tf.Tensor | None = None, + decoder_attention_mask: tf.Tensor | None = None, + decoder_position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + decoder_head_mask: tf.Tensor | None = None, + cross_attn_head_mask: tf.Tensor | None = None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values: list[tf.Tensor] | None = None, + inputs_embeds: tf.Tensor | None = None, + decoder_inputs_embeds: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool | None = False, + **kwargs, + ) -> tuple[tf.Tensor] | TFSeq2SeqModelOutput: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqModelOutput( + last_hidden_state=output.last_hidden_state, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + + +# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer +class BiasLayer(keras.layers.Layer): + """ + Bias as a layer. 
It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis, + so all weights have to be registered in a layer. + """ + + def __init__(self, shape, initializer, trainable, name, **kwargs): + super().__init__(name=name, **kwargs) + # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of + # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see: + # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214 + self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable) + + def call(self, x): + return x + self.bias + + +@add_start_docstrings( + "The BLENDERBOT Model with a language modeling head. Can be used for summarization.", + BLENDERBOT_START_DOCSTRING, +) +class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausalLanguageModelingLoss): + _keys_to_ignore_on_load_unexpected = [ + r"model.encoder.embed_tokens.weight", + r"model.decoder.embed_tokens.weight", + ] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.model = TFBlenderbotMainLayer(config, name="model") + self.use_cache = config.use_cache + # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency. + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False + ) + + def get_decoder(self): + return self.model.decoder + + def get_encoder(self): + return self.model.encoder + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + def get_bias(self): + return {"final_logits_bias": self.bias_layer.bias} + + def set_bias(self, value): + # Replaces the existing layers containing bias for correct (de)serialization. + vocab_size = value["final_logits_bias"].shape[-1] + self.bias_layer = BiasLayer( + name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False + ) + self.bias_layer.bias.assign(value["final_logits_bias"]) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None, *model_args, **kwargs): + if pretrained_model_name_or_path == "facebook/blenderbot-90M": + from ..blenderbot_small import TFBlenderbotSmallForConditionalGeneration + + warnings.warn( + "The checkpoint `facebook/blenderbot-90M` is deprecated. 
In the future, please use the identical" + " checkpoint `facebook/small_blenderbot-90M` with" + " `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')`" + " instead.", + FutureWarning, + ) + return TFBlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @unpack_inputs + @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(BLENDERBOT_GENERATION_EXAMPLE) + def call( + self, + input_ids: tf.Tensor | None = None, + attention_mask: tf.Tensor | None = None, + decoder_input_ids: tf.Tensor | None = None, + decoder_attention_mask: tf.Tensor | None = None, + decoder_position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + decoder_head_mask: tf.Tensor | None = None, + cross_attn_head_mask: tf.Tensor | None = None, + encoder_outputs: tuple | TFBaseModelOutput | None = None, + past_key_values: list[tf.Tensor] | None = None, + inputs_embeds: tf.Tensor | None = None, + decoder_inputs_embeds: tf.Tensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool | None = False, + ) -> tuple[tf.Tensor] | TFSeq2SeqLMOutput: + r""" + labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + """ + if labels is not None: + labels = tf.where( + labels == self.config.pad_token_id, + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), + labels, + ) + use_cache = False + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) + lm_logits = self.bias_layer(lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return TFSeq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, # index 1 of d outputs + decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs + decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs + cross_attentions=outputs.cross_attentions, # index 4 of d outputs + encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs + encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out + encoder_attentions=outputs.encoder_attentions, # 2 of e out + ) + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output + def serving_output(self, output): + pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None + dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None + dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None + enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None + enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None + + return TFSeq2SeqLMOutput( + logits=output.logits, + past_key_values=pkv, + decoder_hidden_states=dec_hs, + decoder_attentions=dec_attns, + cross_attentions=cross_attns, + encoder_last_hidden_state=output.encoder_last_hidden_state, + encoder_hidden_states=enc_hs, + encoder_attentions=enc_attns, + ) + + # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + if decoder_attention_mask is 
not None: # xla + decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:] + elif past_key_values is not None: # no xla + past_key_values + decoder_position_ids = past_key_values[0][0].shape[2] + else: # no xla + no past_key_values + decoder_position_ids = tf.range(decoder_input_ids.shape[1]) + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "model", None) is not None: + with tf.name_scope(self.model.name): + self.model.build(None) + if getattr(self, "bias_layer", None) is not None: + with tf.name_scope(self.bias_layer.name): + self.bias_layer.build(None) + + +__all__ = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/tokenization_blenderbot.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/tokenization_blenderbot.py new file mode 100644 index 0000000000000000000000000000000000000000..76719fa254948b7fa4a9ea151bfc2fcdba07131c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/tokenization_blenderbot.py @@ -0,0 +1,410 @@ +# coding=utf-8 +# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Blenderbot.""" + +import json +import os +from functools import lru_cache +from typing import Optional + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "tokenizer_config_file": "tokenizer_config.json", +} + + +@lru_cache +# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. 
+ """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.roberta.tokenization_roberta.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class BlenderbotTokenizer(PreTrainedTokenizer): + """ + Constructs a Blenderbot tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import BlenderbotTokenizer + + >>> tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B") + >>> tokenizer.add_prefix_space = False + >>> tokenizer("Hello world")["input_ids"] + [47, 921, 86, 1085, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [6950, 1085, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one). + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. 
A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `"<mask>"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows treating the leading word just like any + other word. (The Blenderbot tokenizer detects the beginning of words by the preceding space.) + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.__init__ with Roberta->Blenderbot, RoBERTa->Blenderbot + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + bos_token="<s>", + eos_token="</s>", + sep_token="</s>", + cls_token="<s>", + unk_token="<unk>", + pad_token="<pad>", + mask_token="<mask>", + add_prefix_space=False, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + + # Mask token behave like a normal word, i.e.
include the space before it + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + + # these special tokens are not part of the vocab.json, let's add them in the correct order + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.vocab_size with Roberta->Blenderbot, RoBERTa->Blenderbot + def vocab_size(self): + return len(self.encoder) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_vocab with Roberta->Blenderbot, RoBERTa->Blenderbot + def get_vocab(self): + vocab = dict(self.encoder).copy() + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.bpe with Roberta->Blenderbot, RoBERTa->Blenderbot + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._tokenize with Roberta->Blenderbot, RoBERTa->Blenderbot + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_token_to_id with Roberta->Blenderbot, RoBERTa->Blenderbot + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using 
the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_id_to_token with Roberta->Blenderbot, RoBERTa->Blenderbot + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.convert_tokens_to_string with Roberta->Blenderbot, RoBERTa->Blenderbot + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.save_vocabulary with Roberta->Blenderbot, RoBERTa->Blenderbot + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_special_tokens_mask with Roberta->Blenderbot, RoBERTa->Blenderbot + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
+ """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.create_token_type_ids_from_sequences with Roberta->Blenderbot, RoBERTa->Blenderbot + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. Blenderbot does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.prepare_for_tokenization with Roberta->Blenderbot, RoBERTa->Blenderbot + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + def build_inputs_with_special_tokens(self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A Blenderbot sequence has the following format: + - single sequence: ` X ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (`list[int]`, *optional*): + Will be ignored + Returns: + `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + return token_ids_0 + [self.eos_token_id] + + +__all__ = ["BlenderbotTokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/tokenization_blenderbot_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/tokenization_blenderbot_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..0b84200e02d5c7f89141686c4e3c63c8d7fefc14 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/blenderbot/tokenization_blenderbot_fast.py @@ -0,0 +1,284 @@ +# coding=utf-8 +# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Tokenization class for Blenderbot.""" + +import json +from typing import Optional + +from tokenizers import processors + +from ...tokenization_utils_base import AddedToken, BatchEncoding +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_blenderbot import BlenderbotTokenizer + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", + "tokenizer_config_file": "tokenizer_config.json", +} + + +class BlenderbotTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" Blenderbot tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 + tokenizer, using byte-level Byte-Pair-Encoding. + + This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import BlenderbotTokenizerFast + + >>> tokenizer = BlenderbotTokenizerFast.from_pretrained("facebook/blenderbot-3B") + >>> tokenizer("Hello world")["input_ids"] + [6950, 1085, 2] + + >>> tokenizer(" Hello world")["input_ids"] + [6950, 1085, 2] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`. + + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. 
+ pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (Blenderbot tokenizer detect beginning of words by the preceding space). + trim_offsets (`bool`, *optional*, defaults to `True`): + Whether the post processing step should trim offsets to avoid including whitespaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = BlenderbotTokenizer + + # Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast.__init__ with Roberta->Blenderbot, RoBERTa->Blenderbot + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + trim_offsets=True, + **kwargs, + ): + mask_token = ( + AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False) + if isinstance(mask_token, str) + else mask_token + ) + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + **kwargs, + ) + + tokenizer_component = "post_processor" + tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) + if tokenizer_component_instance: + state = json.loads(tokenizer_component_instance.__getstate__()) + + # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class` + if "sep" in state: + state["sep"] = tuple(state["sep"]) + if "cls" in state: + state["cls"] = tuple(state["cls"]) + + changes_to_apply = False + + if state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + state["add_prefix_space"] = add_prefix_space + changes_to_apply = True + + if state.get("trim_offsets", trim_offsets) != trim_offsets: + state["trim_offsets"] = trim_offsets + changes_to_apply = True + + if changes_to_apply: + component_class = getattr(processors, state.pop("type")) + new_value = component_class(**state) + setattr(self.backend_tokenizer, tokenizer_component, new_value) + + @property + # Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast.mask_token with Roberta->Blenderbot, RoBERTa->Blenderbot + def mask_token(self) -> str: + """ + `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not + having been set. + + Blenderbot tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily + comprise the space before the **. 
+ """ + if self._mask_token is None: + if self.verbose: + logger.error("Using mask_token, but it is not set yet.") + return None + return str(self._mask_token) + + @mask_token.setter + def mask_token(self, value): + """ + Overriding the default behavior of the mask token to have it eat the space before it. + + This is needed to preserve backward compatibility with all the previously used models based on Roberta. + """ + # Mask token behave like a normal word, i.e. include the space before it + # So we set lstrip to True + value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value + self._mask_token = value + + # Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast._batch_encode_plus with Roberta->Blenderbot, RoBERTa->Blenderbot + def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._batch_encode_plus(*args, **kwargs) + + # Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast._encode_plus with Roberta->Blenderbot, RoBERTa->Blenderbot + def _encode_plus(self, *args, **kwargs) -> BatchEncoding: + is_split_into_words = kwargs.get("is_split_into_words", False) + + assert self.add_prefix_space or not is_split_into_words, ( + f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True " + "to use it with pretokenized inputs." + ) + + return super()._encode_plus(*args, **kwargs) + + # Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast.save_vocabulary with Roberta->Blenderbot, RoBERTa->Blenderbot + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + # Copied from transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast.create_token_type_ids_from_sequences with Roberta->Blenderbot, RoBERTa->Blenderbot + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. Blenderbot does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def build_inputs_with_special_tokens(self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A Blenderbot sequence has the following format: + - single sequence: ` X ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (`list[int]`, *optional*): + Will be ignored + Returns: + `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + return token_ids_0 + [self.eos_token_id] + + +__all__ = ["BlenderbotTokenizerFast"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/byt5/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/byt5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb726942b0f16105f8a5a5f7c661485951d7ccc7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/byt5/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_byt5 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/byt5/tokenization_byt5.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/byt5/tokenization_byt5.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9804db1014a963fb2054083f2db2782a41016d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/byt5/tokenization_byt5.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2021 T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for model ByT5.""" + +import warnings +from typing import Optional + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ByT5Tokenizer(PreTrainedTokenizer): + """ + Construct a ByT5 tokenizer. ByT5 simply uses raw bytes utf-8 encoding. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. 
+ pad_token (`str`, *optional*, defaults to `"<pad>"`): + The token used for padding, for example when batching sequences of different lengths. + extra_ids (`int`, *optional*, defaults to 125): + Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are + accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are + indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary + like in ByT5 preprocessing see + [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)). + additional_special_tokens (`list[str]`, *optional*): + Additional special tokens used by the tokenizer. + """ + + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + eos_token="</s>", + unk_token="<unk>", + pad_token="<pad>", + extra_ids=125, + additional_special_tokens=None, + **kwargs, + ) -> None: + # Add extra_ids to the special token list + if extra_ids > 0 and additional_special_tokens is None: + additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)] + elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0: + # Check that we have the right number of extra_id special tokens + extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens))) + if extra_tokens != extra_ids: + raise ValueError( + f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are" + " provided to ByT5Tokenizer. In this case the additional_special_tokens must include the" + " extra_ids tokens" + ) + + pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token + # we force left and right stripping for backward compatibility. The byt5tests depend on this. + eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token + # unk token needs to be in the vocab with correct index + self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token} + self.offset = len(self._added_tokens_decoder) + self._utf_vocab_size = 2**8 # utf is 8 bits + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + extra_ids=0, + additional_special_tokens=additional_special_tokens, # TODO extra ids are not used :sweatywmile: + **kwargs, + ) + + @property + def vocab_size(self): + return self._utf_vocab_size + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)} + vocab.update(self.added_tokens_encoder) + return vocab + + def get_special_tokens_mask( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False + ) -> list[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model.
+ + Returns: + `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. ByT5 does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`list[int]`): + List of IDs. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`list[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
+ """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + def _tokenize(self, text: str) -> list[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + tokens = [chr(i) for i in text.encode("utf-8")] + return tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + + if len(token) != 1: + token_id = None + else: + token_id = ord(token) + self.offset + + return token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = chr(index - self.offset) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + bstring = b"" + for token in tokens: + if token in self.added_tokens_decoder: + tok_string = self.added_tokens_decoder[token].encode("utf-8") + elif token in self.added_tokens_encoder: + tok_string = token.encode("utf-8") + else: + tok_string = bytes([ord(token)]) + bstring += tok_string + string = bstring.decode("utf-8", errors="ignore") + return string + + # ByT5Tokenizer has no vocab file + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]: + return () + + +__all__ = ["ByT5Tokenizer"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d54ee86aecef2cbe5b9bfdee321a0375d977880 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_clap import * + from .feature_extraction_clap import * + from .modeling_clap import * + from .processing_clap import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/configuration_clap.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/configuration_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..900e8d373f5ad15d3d90f2da655abe10f3279f85 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/configuration_clap.py @@ -0,0 +1,382 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CLAP model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ClapTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ClapTextModel`]. It is used to instantiate a CLAP + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the CLAP + [calp-hsat-fused](https://huggingface.co/laion/clap-hsat-fused) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the CLAP model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ClapTextModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"relu"`, + `"relu"`, `"silu"` and `"relu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ClapTextModel`]. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). 
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + projection_hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + projection_dim (`int`, *optional*, defaults to 512) + Dimension of the projection head of the `ClapTextModelWithProjection`. + + Examples: + + ```python + >>> from transformers import ClapTextConfig, ClapTextModel + + >>> # Initializing a CLAP text configuration + >>> configuration = ClapTextConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = ClapTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clap_text_model" + base_config_key = "text_config" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=514, + type_vocab_size=1, + initializer_factor=1.0, + layer_norm_eps=1e-12, + projection_dim=512, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + projection_hidden_act="relu", + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_factor = initializer_factor + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.projection_hidden_act = projection_hidden_act + self.projection_dim = projection_dim + + +class ClapAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ClapAudioModel`]. It is used to instantiate a + CLAP audio encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the audio encoder of the CLAP + [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + window_size (`int`, *optional*, defaults to 8): + Image size of the spectrogram + num_mel_bins (`int`, *optional*, defaults to 64): + Number of mel features used per frames. 
Should correspond to the value used in the `ClapProcessor` class. + spec_size (`int`, *optional*, defaults to 256): + Desired input size of the spectrogram that the model supports. It can be different from the output of the + `ClapFeatureExtractor`, in which case the input features will be resized. Corresponds to the `image_size` + of the audio models. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + patch_size (`int`, *optional*, defaults to 4): + Patch size for the audio spectrogram + patch_stride (`list`, *optional*, defaults to `[4, 4]`): + Patch stride for the audio spectrogram + num_classes (`int`, *optional*, defaults to 527): + Number of classes used for the head training + hidden_size (`int`, *optional*, defaults to 768): + Hidden size of the output of the audio encoder. Correspond to the dimension of the penultimate layer's + output,which is sent to the projection MLP layer. + projection_dim (`int`, *optional*, defaults to 512): + Hidden size of the projection layer. + depths (`list`, *optional*, defaults to `[2, 2, 6, 2]`): + Depths used for the Swin Layers of the audio model + num_attention_heads (`list`, *optional*, defaults to `[4, 8, 16, 32]`): + Number of attention heads used for the Swin Layers of the audio model + enable_fusion (`bool`, *optional*, defaults to `False`): + Whether or not to enable patch fusion. This is the main contribution of the authors, and should give the + best results. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the encoder. + fusion_type (`[type]`, *optional*): + Fusion type used for the patch fusion. + patch_embed_input_channels (`int`, *optional*, defaults to 1): + Number of channels used for the input spectrogram + flatten_patch_embeds (`bool`, *optional*, defaults to `True`): + Whether or not to flatten the patch embeddings + patch_embeds_hidden_size (`int`, *optional*, defaults to 96): + Hidden size of the patch embeddings. It is used as the number of output channels. + enable_patch_layer_norm (`bool`, *optional*, defaults to `True`): + Whether or not to enable layer normalization for the patch embeddings + drop_path_rate (`float`, *optional*, defaults to 0.0): + Drop path rate for the patch fusion + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to add a bias to the query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of the mlp hidden dim to embedding dim. + aff_block_r (`int`, *optional*, defaults to 4): + downsize_ratio used in the AudioFF block + num_hidden_layers (`int`, *optional*, defaults to 4): + Number of hidden layers in the Transformer encoder. + projection_hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the projection layer. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + layer_norm_eps (`[type]`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). 
+ + Example: + + ```python + >>> from transformers import ClapAudioConfig, ClapAudioModel + + >>> # Initializing a ClapAudioConfig with laion/clap-htsat-fused style configuration + >>> configuration = ClapAudioConfig() + + >>> # Initializing a ClapAudioModel (with random weights) from the laion/clap-htsat-fused style configuration + >>> model = ClapAudioModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "clap_audio_model" + base_config_key = "audio_config" + + def __init__( + self, + window_size=8, + num_mel_bins=64, + spec_size=256, + hidden_act="gelu", + patch_size=4, + patch_stride=[4, 4], + num_classes=527, + hidden_size=768, + projection_dim=512, + depths=[2, 2, 6, 2], + num_attention_heads=[4, 8, 16, 32], + enable_fusion=False, + hidden_dropout_prob=0.1, + fusion_type=None, + patch_embed_input_channels=1, + flatten_patch_embeds=True, + patch_embeds_hidden_size=96, + enable_patch_layer_norm=True, + drop_path_rate=0.0, + attention_probs_dropout_prob=0.0, + qkv_bias=True, + mlp_ratio=4.0, + aff_block_r=4, + num_hidden_layers=4, + projection_hidden_act="relu", + layer_norm_eps=1e-5, + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.window_size = window_size + self.num_mel_bins = num_mel_bins + self.spec_size = spec_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.num_classes = num_classes + self.hidden_size = hidden_size + self.depths = depths + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.window_size = window_size + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.projection_dim = projection_dim + self.flatten_patch_embeds = flatten_patch_embeds + self.patch_embeds_hidden_size = patch_embeds_hidden_size + self.enable_patch_layer_norm = enable_patch_layer_norm + self.drop_path_rate = drop_path_rate + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.patch_embed_input_channels = patch_embed_input_channels + self.aff_block_r = aff_block_r + self.layer_norm_eps = layer_norm_eps + self.initializer_factor = initializer_factor + self.projection_hidden_act = projection_hidden_act + + +class ClapConfig(PretrainedConfig): + r""" + [`ClapConfig`] is the configuration class to store the configuration of a [`ClapModel`]. It is used to instantiate + a CLAP model according to the specified arguments, defining the text model and audio model configs. Instantiating a + configuration with the defaults will yield a similar configuration to that of the CLAP + [laion/clap-htsat-fused](https://huggingface.co/laion/clap-htsat-fused) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ClapTextConfig`]. + audio_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`ClapAudioConfig`]. + logit_scale_init_value (`float`, *optional*, defaults to 14.29): + The initial value of the *logit_scale* parameter. Default is used as per the original CLAP implementation. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and audio projection layers. 
+ projection_hidden_act (`str`, *optional*, defaults to `"relu"`): + Activation function for the projection layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + Factor to scale the initialization of the model weights. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import ClapConfig, ClapModel + + >>> # Initializing a ClapConfig with laion-ai/base style configuration + >>> configuration = ClapConfig() + + >>> # Initializing a ClapModel (with random weights) from the laion-ai/base style configuration + >>> model = ClapModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + + >>> # We can also initialize a ClapConfig from a ClapTextConfig and a ClapAudioConfig + >>> from transformers import ClapTextConfig, ClapAudioConfig + + >>> # Initializing a ClapText and ClapAudioConfig configuration + >>> config_text = ClapTextConfig() + >>> config_audio = ClapAudioConfig() + + >>> config = ClapConfig.from_text_audio_configs(config_text, config_audio) + ```""" + + model_type = "clap" + sub_configs = {"text_config": ClapTextConfig, "audio_config": ClapAudioConfig} + + def __init__( + self, + text_config=None, + audio_config=None, + logit_scale_init_value=(1 / 0.07), + projection_dim=512, + projection_hidden_act="relu", + initializer_factor=1.0, + **kwargs, + ): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config is None. Initializing the ClapTextConfig with default values.") + + if audio_config is None: + audio_config = {} + logger.info("audio_config is None. initializing the ClapAudioConfig with default values.") + + self.text_config = ClapTextConfig(**text_config) + self.audio_config = ClapAudioConfig(**audio_config) + self.text_config.projection_dim = projection_dim + self.audio_config.projection_dim = projection_dim + + self.text_config.projection_hidden_act = projection_hidden_act + self.audio_config.projection_hidden_act = projection_hidden_act + + self.projection_dim = projection_dim + self.projection_hidden_act = projection_hidden_act + self.hidden_size = self.text_config.hidden_size + + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = initializer_factor + self.num_hidden_layers = self.text_config.num_hidden_layers + len(self.audio_config.depths) + + +__all__ = ["ClapAudioConfig", "ClapConfig", "ClapTextConfig"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/feature_extraction_clap.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/feature_extraction_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..33daac615c07ab879d8515e04e09d01fe27b37fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/feature_extraction_clap.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
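The `ClapConfig.__init__` shown earlier propagates `projection_dim` and `projection_hidden_act` into both sub-configs, so the text and audio towers always project into the same joint embedding space, and `num_hidden_layers` is derived from the two towers combined. A minimal sketch of that behaviour, relying only on the constructor above (the concrete values are illustrative, not defaults):

```python
from transformers import ClapConfig

# Compose a config from plain dicts; unspecified fields fall back to the
# ClapTextConfig / ClapAudioConfig defaults.
config = ClapConfig(
    text_config={"num_hidden_layers": 6},
    audio_config={"depths": [2, 2, 6, 2]},
    projection_dim=256,
)

# __init__ mirrors the projection settings into both sub-configs.
assert config.text_config.projection_dim == 256
assert config.audio_config.projection_dim == 256

# num_hidden_layers counts the text layers plus one per audio stage.
assert config.num_hidden_layers == 6 + len(config.audio_config.depths)
```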
+"""Feature extractor class for CLAP.""" + +import copy +from typing import Any, Optional, Union + +import numpy as np +import torch + +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import TensorType, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("torch",)) +class ClapFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a CLAP feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time + Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent. + + Args: + feature_size (`int`, *optional*, defaults to 64): + The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters + (`n_mels`). + sampling_rate (`int`, *optional*, defaults to 48000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves + to warn users if the audio fed to the feature extractor does not have the same sampling rate. + hop_length (`int`,*optional*, defaults to 480): + Length of the overlapping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split + in smaller `frames` with a step of `hop_length` between each frame. + max_length_s (`int`, *optional*, defaults to 10): + The maximum input length of the model in seconds. This is used to pad the audio. + fft_window_size (`int`, *optional*, defaults to 1024): + Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency + resolution of the spectrogram. 400 means that the fourier transform is computed on windows of 400 samples. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + return_attention_mask (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the attention masks corresponding to the input. + frequency_min (`float`, *optional*, defaults to 0): + The lowest frequency of interest. The STFT will not be computed for values below this. + frequency_max (`float`, *optional*, defaults to 14000): + The highest frequency of interest. The STFT will not be computed for values above this. + top_db (`float`, *optional*): + The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the + `audio_utils.power_to_db` function + truncation (`str`, *optional*, defaults to `"fusion"`): + Truncation pattern for long audio inputs. Two patterns are available: + - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a + downsampled version of the entire mel spectrogram. + If `config.fusion` is set to True, shorter audios also need to to return 4 mels, which will just be a copy + of the original mel obtained from the padded audio. + - `rand_trunc` will select a random crop of the mel spectrogram. + padding (`str`, *optional*, defaults to `"repeatpad"`): + Padding pattern for shorter audio inputs. 
Three patterns were originally implemented: + - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`. + - `repeat`: the audio is repeated and then cut to fit the `max_length` + - `pad`: the audio is padded. + """ + + model_input_names = ["input_features", "is_longer"] + + def __init__( + self, + feature_size=64, + sampling_rate=48_000, + hop_length=480, + max_length_s=10, + fft_window_size=1024, + padding_value=0.0, + return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask + frequency_min: float = 0, + frequency_max: float = 14_000, + top_db: Optional[int] = None, + truncation: str = "fusion", + padding: str = "repeatpad", + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.top_db = top_db + self.truncation = truncation + self.padding = padding + self.fft_window_size = fft_window_size + self.nb_frequency_bins = (fft_window_size >> 1) + 1 + self.hop_length = hop_length + self.max_length_s = max_length_s + self.nb_max_samples = max_length_s * sampling_rate + self.sampling_rate = sampling_rate + self.frequency_min = frequency_min + self.frequency_max = frequency_max + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.nb_frequency_bins, + num_mel_filters=feature_size, + min_frequency=frequency_min, + max_frequency=frequency_max, + sampling_rate=sampling_rate, + norm=None, + mel_scale="htk", + ) + self.mel_filters_slaney = mel_filter_bank( + num_frequency_bins=self.nb_frequency_bins, + num_mel_filters=feature_size, + min_frequency=frequency_min, + max_frequency=frequency_max, + sampling_rate=sampling_rate, + norm="slaney", + mel_scale="slaney", + ) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except for the + mel filter banks, which do not need to be saved or printed as they are too long. + """ + output = copy.deepcopy(self.__dict__) + output["feature_extractor_type"] = self.__class__.__name__ + if "mel_filters" in output: + del output["mel_filters"] + if "mel_filters_slaney" in output: + del output["mel_filters_slaney"] + return output + + def _np_extract_fbank_features(self, waveform: np.ndarray, mel_filters: Optional[np.ndarray] = None) -> np.ndarray: + """ + Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter + banks are used depending on the truncation pattern: + - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from + calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` + is set to `"fusion"`. + - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used + `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original + implementation when the truncation mode is not `"fusion"`. 
+ """ + log_mel_spectrogram = spectrogram( + waveform, + window_function(self.fft_window_size, "hann"), + frame_length=self.fft_window_size, + hop_length=self.hop_length, + power=2.0, + mel_filters=mel_filters, + log_mel="dB", + ) + return log_mel_spectrogram.T + + def _random_mel_fusion(self, mel, total_frames, chunk_frames): + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + mel = torch.tensor(mel[None, None, :]) + mel_shrink = torch.nn.functional.interpolate( + mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False + ) + mel_shrink = mel_shrink[0][0].numpy() + mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) + return mel_fusion + + def _get_input_mel(self, waveform: np.ndarray, max_length, truncation, padding) -> np.ndarray: + """ + Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments. + Four different path are possible: + - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram + will be computed on the entire audio. 3 random crops and a dowsampled version of the full mel spectrogram + are then stacked together. They will later be used for `feature_fusion`. + - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is + padded based on `padding`. + - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded + based on `padding`, and is repeated `4` times. + - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel + spectrogram will be computed on a random crop of the waveform. + + """ + if waveform.shape[0] > max_length: + if truncation == "rand_trunc": + longer = True + # random crop to max_length (for compatibility) -> this should be handled by self.pad + overflow = len(waveform) - max_length + idx = np.random.randint(0, overflow + 1) + waveform = waveform[idx : idx + max_length] + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + elif truncation == "fusion": + mel = self._np_extract_fbank_features(waveform, self.mel_filters) + chunk_frames = max_length // self.hop_length + 1 # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is larger than max_length but smaller than max_length+hop_length. + # In this case, we just use the whole audio. + input_mel = np.stack([mel, mel, mel, mel], axis=0) + longer = False + else: + input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames) + longer = True + else: + raise NotImplementedError(f"data_truncating {truncation} not implemented") + + else: + longer = False + # only use repeat as a new possible value for padding. 
you repeat the audio before applying the usual max_length padding + if waveform.shape[0] < max_length: + if padding == "repeat": + n_repeat = int(max_length / len(waveform)) + waveform = np.tile(waveform, n_repeat + 1)[:max_length] + if padding == "repeatpad": + n_repeat = int(max_length / len(waveform)) + waveform = np.tile(waveform, n_repeat) + waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0) + + if truncation == "fusion": + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters) + input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0) + else: + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + + return input_mel, longer + + def __call__( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + truncation: Optional[str] = None, + padding: Optional[str] = None, + max_length: Optional[int] = None, + sampling_rate: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + truncation (`str`, *optional*): + Truncation pattern for long audio inputs. Two patterns are available: + - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and + a downsampled version of the entire mel spectrogram. + If `config.fusion` is set to True, shorter audios also need to to return 4 mels, which will just be a + copy of the original mel obtained from the padded audio. + - `rand_trunc` will select a random crop of the mel spectrogram. + padding (`str`, *optional*): + Padding pattern for shorter audio inputs. Three patterns were originally implemented: + - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`. + - `repeat`: the audio is repeated and then cut to fit the `max_length` + - `pad`: the audio is padded. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.np.array` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition + pipeline. + """ + truncation = truncation if truncation is not None else self.truncation + padding = padding if padding else self.padding + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" + f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" + f" was sampled with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. 
" + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float64) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float64) + + # always return batch + if not is_batched: + raw_speech = [np.asarray(raw_speech)] + + # convert to mel spectrogram, truncate and pad if needed. + padded_inputs = [ + self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding) + for waveform in raw_speech + ] + + input_mel = [] + is_longer = [] + for mel, longer in padded_inputs: + input_mel.append(mel) + is_longer.append(longer) + + if truncation == "fusion" and sum(is_longer) == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + rand_idx = np.random.randint(0, len(input_mel)) + is_longer[rand_idx] = True + + if isinstance(input_mel[0], list): + input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel] + + # is_longer is a list of bool + is_longer = [[longer] for longer in is_longer] + + input_features = {"input_features": input_mel, "is_longer": is_longer} + input_features = BatchFeature(input_features) + + if return_tensors is not None: + input_features = input_features.convert_to_tensors(return_tensors) + + return input_features + + +__all__ = ["ClapFeatureExtractor"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/modeling_clap.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/modeling_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..9d81a26581dd35dcfcc06b0f1881640acd47a070 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/modeling_clap.py @@ -0,0 +1,1929 @@ +# coding=utf-8 +# Copyright 2023 The LAION-AI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch CLAP model.""" + +import collections +import math +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int +from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig + + +logger = logging.get_logger(__name__) + + +# Adapted from: https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/utils.py#L191 +def interpolate(hidden_states, ratio): + """ + Interpolate data in time domain. This is used to compensate the resolution reduction in downsampling of a CNN. + + Args: + hidden_states (`torch.FloatTensor` of shape (batch_size, time_length, classes_num)): + Input hidden states + ratio (`int`): + The ratio of the length of the output to the length of the input. + """ + (batch_size, time_length, classes_num) = hidden_states.shape + upsampled = hidden_states[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_length * ratio, classes_num) + return upsampled + + +# Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L249 +def window_partition(hidden_states, window_size): + """ + Returns the resized hidden states. The output shape should be `(batch_size * num_windows, window_size, window_size, + num_channels)` + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`): + Input hidden states + window_size (`int`): + Window size + """ + batch_size, height, width, num_channels = hidden_states.shape + + hidden_states = hidden_states.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L263 +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + Args: + windows (`torch.FloatTensor` of shape `(num_windows * batch_size, window_size, window_size, num_channels)`): + Input windows + window_size (`int`): + Window size + height (`int`): + Height of the resized audio + width (`int`): + Width of the resized audio + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. 
This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html#CLIP-loss-function +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + labels = torch.arange(len(logits), device=logits.device) + return nn.functional.cross_entropy(logits, labels) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Clap +class ClapTextModelOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + ClapAudio model output to mimic the output of the original implementation. + """ +) +class ClapAudioModelOutput(ModelOutput): + r""" + audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + The Audio embeddings obtained by applying the projection layer to the pooler_output. + """ + + audio_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + attentions: Optional[tuple[torch.FloatTensor, ...]] = None + + +@dataclass +@auto_docstring +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Clap, vision->audio, Vision->Audio, image->audio +class ClapOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for audio-text similarity. + logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`): + The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`): + The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`]. + audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`ClapTextModel`]. + audio_model_output (`BaseModelOutputWithPooling`): + The output of the [`ClapAudioModel`]. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits_per_audio: Optional[torch.FloatTensor] = None + logits_per_text: Optional[torch.FloatTensor] = None + text_embeds: Optional[torch.FloatTensor] = None + audio_embeds: Optional[torch.FloatTensor] = None + text_model_output: BaseModelOutputWithPooling = None + audio_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "audio_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Adapted from transformers.models.swin.modeling_swin.SwinDropPath +class ClapDropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is a slightly + refactored version of the `SwinDropPath` implementation. + """ + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states): + if self.drop_prob == 0.0 or not self.training: + return hidden_states + + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + shape = (hidden_states.shape[0],) + (1,) * (hidden_states.ndim - 1) + + random_tensor = keep_prob + torch.rand(shape, dtype=hidden_states.dtype, device=hidden_states.device) + random_tensor.floor_() # binarize + output = hidden_states.div(keep_prob) * random_tensor + return output + + +# Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/feature_fusion.py#L133 +class ClapAudioAFFBlock(nn.Module): + r""" + ATTENTIONAL FEATURE FUSION Block from CLAP, since in CLAP we are always in 2D mode, it is not needed to implement + the 1D version. + """ + + def __init__(self, config: ClapAudioConfig): + super().__init__() + channels = config.patch_embeds_hidden_size + downsize_ratio = config.aff_block_r + inter_channels = int(channels // downsize_ratio) + + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states, residual): + attention_input = hidden_states + residual + + fused_layer_output = self.local_att(attention_input) + self.global_att(attention_input) + fused_layer_output = self.sigmoid(fused_layer_output) + + output = 2 * hidden_states * fused_layer_output + 2 * residual * (1 - fused_layer_output) + return output + + +class ClapAudioPatchEmbed(nn.Module): + """ + This module converts the hidden states reshaped as an image to patch embeddings ready to be passed to the + Transformer block. 
+ """ + + def __init__(self, config: ClapAudioConfig): + super().__init__() + img_size = (config.spec_size, config.spec_size) if isinstance(config.spec_size, int) else config.spec_size + patch_size = ( + (config.patch_size, config.patch_size) if isinstance(config.patch_size, int) else config.patch_size + ) + patch_stride = ( + (config.patch_stride, config.patch_stride) if isinstance(config.patch_stride, int) else config.patch_stride + ) + + self.img_size = img_size + self.patch_stride = patch_stride + + self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.flatten = config.flatten_patch_embeds + self.enable_fusion = config.enable_fusion + + padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2) + + scale_factor = 4 if (self.enable_fusion) and (config.fusion_type == "channel_map") else 1 + + self.proj = nn.Conv2d( + config.patch_embed_input_channels * scale_factor, + config.patch_embeds_hidden_size, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + + self.norm = nn.LayerNorm(config.patch_embeds_hidden_size) if config.enable_patch_layer_norm else nn.Identity() + if self.enable_fusion: + self.fusion_model = ClapAudioAFFBlock(config) + self.mel_conv2d = nn.Conv2d( + config.patch_embed_input_channels, + config.patch_embeds_hidden_size, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + + def forward(self, hidden_states, is_longer_idx=None): + if self.enable_fusion: + # retrieve the last mel as we have transposed the input + global_hidden_states = hidden_states[:, 0:1, :, :] + + # global processing + batch_size, num_channels, height, width = global_hidden_states.shape + + if height != self.img_size[0] or width != self.img_size[1]: + raise ValueError( + f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + ) + + global_hidden_states = self.proj(global_hidden_states) + output_width = global_hidden_states.size(-1) + if len(is_longer_idx) > 0: + # local processing + local_hidden_states = hidden_states[is_longer_idx, 1:, :, :].contiguous() + batch_size, num_channels, height, width = local_hidden_states.shape + local_hidden_states = local_hidden_states.view(batch_size * num_channels, 1, height, width) + + local_hidden_states = self.mel_conv2d(local_hidden_states) + + _, features, height, width = local_hidden_states.shape + local_hidden_states = local_hidden_states.view(batch_size, num_channels, features, height, width) + local_hidden_states = local_hidden_states.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + + local_width = local_hidden_states.size(-1) + local_hidden_states = torch.nn.functional.pad( + local_hidden_states, (0, output_width - local_width), "constant", 0 + ) + + global_hidden_states[is_longer_idx] = self.fusion_model( + global_hidden_states[is_longer_idx], local_hidden_states + ) + hidden_states = global_hidden_states + else: + _, _, height, width = hidden_states.shape + if height != self.img_size[0] or width != self.img_size[1]: + raise ValueError( + f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ ) + hidden_states = self.proj(hidden_states) + + if self.flatten: + hidden_states = hidden_states.flatten(2).transpose(1, 2) + hidden_states = self.norm(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->ClapAudio +class ClapAudioSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + hidden_shape = (batch_size, dim, -1, self.attention_head_size) + + query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ClapAudioModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->ClapAudio +class ClapAudioSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->ClapAudio +class ClapAudioAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + self.self = ClapAudioSelfAttention(config, dim, num_heads, window_size) + self.output = ClapAudioSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->ClapAudio +class ClapAudioIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->ClapAudio +class ClapAudioOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with SwinDropPath->ClapDropPath, Swin->ClapAudio +class ClapAudioLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = ClapAudioAttention(config, dim, num_heads, window_size=self.window_size) + self.drop_path = ClapDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = ClapAudioIntermediate(config, dim) + self.output = ClapAudioOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = 
attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio +class ClapAudioStage(GradientCheckpointingLayer): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + ClapAudioLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging 
layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging with Swin->ClapAudio +class ClapAudioPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be divisible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +class ClapAudioEncoder(nn.Module): + def __init__(self, config): + super().__init__() + 
self.num_layers = len(config.depths) + + self.config = config + self.patch_embed = ClapAudioPatchEmbed(config) + self.enable_fusion = config.enable_fusion + self.patch_stride = self.patch_embed.patch_stride + self.spec_size = config.spec_size + self.freq_ratio = config.spec_size // config.num_mel_bins + + self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1)) + + drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + + grid_size = self.patch_embed.grid_size + self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)] + + self.layers = nn.ModuleList( + [ + ClapAudioStage( + config=config, + dim=int(config.patch_embeds_hidden_size * 2**i_layer), + input_resolution=self.input_resolutions[i_layer], + depth=config.depths[i_layer], + num_heads=config.num_attention_heads[i_layer], + drop_path=drop_path_rate[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=ClapAudioPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + self.batch_norm = nn.BatchNorm2d(config.num_mel_bins) + self.norm = nn.LayerNorm(self.num_features) + self.depths = config.depths + self.avgpool = nn.AdaptiveAvgPool1d(1) + + def reshape_mel2img(self, normalized_input_features): + """ + The input is 4 normalized log mel spectrograms. It is reshape to the common shape of images. Each channel + should represent 1 of the 4 crops of the spectrogram. For more details, refer to the [`ClapFeatureExtractor`]. + """ + _, _, time_length, freq_length = normalized_input_features.shape + + spec_width = int(self.spec_size * self.freq_ratio) + spec_height = self.spec_size // self.freq_ratio + + if time_length > spec_width or freq_length > spec_height: + raise ValueError("the wav size should be less than or equal to the swin input size") + + # to avoid bicubic zero error + if time_length < spec_width: + normalized_input_features = nn.functional.interpolate( + normalized_input_features, (spec_width, freq_length), mode="bicubic", align_corners=True + ) + if freq_length < spec_height: + normalized_input_features = nn.functional.interpolate( + normalized_input_features, (time_length, spec_height), mode="bicubic", align_corners=True + ) + + batch, channels, time, freq = normalized_input_features.shape + + # batch_size, channels, spec_width, spec_height --> batch_size, channels, spec_height * freq_ratio, spec_width // freq_ratio + normalized_input_features = normalized_input_features.reshape( + batch, channels * self.freq_ratio, time // self.freq_ratio, freq + ) + normalized_input_features = normalized_input_features.permute(0, 1, 3, 2).contiguous() + normalized_input_features = normalized_input_features.reshape( + batch, channels, freq * self.freq_ratio, time // self.freq_ratio + ) + + return normalized_input_features + + def forward( + self, + input_features, + is_longer: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[tuple, ClapAudioModelOutput]: + input_features = input_features.transpose(1, 3) + normalized_input_features = self.batch_norm(input_features) + normalized_input_features = 
normalized_input_features.transpose(1, 3) + + is_longer_list_idx = None + if self.enable_fusion: + is_longer_list = is_longer.to(input_features.device) + is_longer_list_idx = torch.where(is_longer_list == 1)[0] + + hidden_states = self.reshape_mel2img(normalized_input_features) + + frames_num = hidden_states.shape[2] + + hidden_states = self.patch_embed(hidden_states, is_longer_list_idx) + + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + input_dimensions = self.input_resolutions[0] + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange batch_size (height width) channels -> batch_size channel height width + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + input_dimensions = self.input_resolutions[i] + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange batch_size (height width) channels -> batch_size channel height width + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange batch_size (height width) channels -> batch_size channel height width + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + last_hidden_state = self.norm(hidden_states) + + batch_size, _, n_channels = last_hidden_state.shape + + freq_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + temporal_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + + last_hidden_state = ( + last_hidden_state.permute(0, 2, 1).contiguous().reshape(batch_size, n_channels, freq_shape, temporal_shape) + ) + + batch_size, n_channels, n_frequencies, n_temp = last_hidden_state.shape + # group 2D CNN + c_freq_bin = n_frequencies // self.freq_ratio + last_hidden_state = last_hidden_state.reshape( + batch_size, n_channels, n_frequencies // c_freq_bin, c_freq_bin, n_temp + ) + last_hidden_state = ( + last_hidden_state.permute(0, 1, 3, 2, 4).contiguous().reshape(batch_size, n_channels, c_freq_bin, -1) + ) + latent_output = 
self.avgpool(torch.flatten(last_hidden_state, 2)) + latent_output = torch.flatten(latent_output, 1) + + if not return_dict: + return tuple( + v + for v in [ + last_hidden_state, + latent_output, + all_reshaped_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=latent_output, + hidden_states=all_reshaped_hidden_states, + attentions=all_self_attentions, + ) + + +class ClapProjectionLayer(nn.Module): + def __init__(self, config: Union[ClapAudioConfig, ClapTextConfig]): + super().__init__() + self.config = config + hidden_size = config.hidden_size + projection_dim = config.projection_dim + + self.linear1 = nn.Linear(hidden_size, projection_dim) + self.activation = ACT2FN[config.projection_hidden_act] + self.linear2 = nn.Linear(projection_dim, projection_dim) + + def forward(self, hidden_states): + hidden_states = self.linear1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->ClapText, persistent=False->persistent=True +class ClapTextEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=True + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. 
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.align.modeling_align.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap +class ClapTextSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + 
self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.attention_dropout = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.attention_head_size) + + query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + head_mask=head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class ClapTextSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap +class ClapTextAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = ClapTextSelfAttention(config) + self.output = ClapTextSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> 
tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class ClapTextIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class ClapTextOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap +class ClapTextLayer(GradientCheckpointingLayer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ClapTextAttention(config) + self.intermediate = ClapTextIntermediate(config) + self.output = ClapTextOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> tuple[torch.Tensor]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + **kwargs, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap +class ClapTextEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + @can_return_tuple + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: 
Optional[bool] = True, + **kwargs, + ) -> Union[tuple[torch.Tensor], BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class ClapTextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@auto_docstring +class ClapPreTrainedModel(PreTrainedModel): + config: ClapConfig + base_model_prefix = "clap" + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + factor = self.config.initializer_factor + + if isinstance(module, ClapTextEmbeddings): + module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_type_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, ClapModel): + module.logit_scale_a.data.fill_(math.log(self.config.logit_scale_init_value)) + module.logit_scale_t.data.fill_(math.log(self.config.logit_scale_init_value)) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, (nn.Conv2d, nn.Linear)): + in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor + nn.init.normal_(module.weight, std=in_proj_std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, ClapAudioSelfAttention): + module.relative_position_bias_table.data.zero_() + + +class ClapAudioModel(ClapPreTrainedModel): + config: ClapAudioConfig + main_input_name = "input_features" + + def __init__(self, config: ClapAudioConfig): + super().__init__(config) + self.audio_encoder = ClapAudioEncoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.audio_encoder.patch_embed.proj + + @auto_docstring + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + is_longer: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + r""" + is_longer (`torch.FloatTensor`, of shape 
`(batch_size, 1)`, *optional*): + Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance + the features. + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, ClapAudioModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused") + >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused") + + >>> inputs = processor(audios=audio_sample, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + return self.audio_encoder( + input_features=input_features, + is_longer=is_longer, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@auto_docstring( + custom_intro=""" + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + .. 
_*Attention is all you need*: https://huggingface.co/papers/1706.03762 + """ +) +class ClapTextModel(ClapPreTrainedModel): + config: ClapTextConfig + + def __init__(self, config, add_pooling_layer=True): + r""" + add_pooling_layer (bool, *optional*, defaults to `True`): + Whether to add a pooling layer + """ + super().__init__(config) + self.config = config + + self.embeddings = ClapTextEmbeddings(config) + self.encoder = ClapTextEncoder(config) + + self.pooler = ClapTextPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class ClapModel(ClapPreTrainedModel): + config: ClapConfig + + def __init__(self, config: ClapConfig): + super().__init__(config) + + if not isinstance(config.text_config, ClapTextConfig): + raise TypeError( + "config.text_config is expected to be of type ClapTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.audio_config, ClapAudioConfig): + raise TypeError( + "config.audio_config is expected to be of type ClapAudioConfig but is of type" + f" {type(config.audio_config)}." + ) + + text_config = config.text_config + audio_config = config.audio_config + + self.logit_scale_a = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value))) + self.logit_scale_t = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value))) + + self.projection_dim = config.projection_dim + + self.text_model = ClapTextModel(text_config) + self.text_projection = ClapProjectionLayer(text_config) + + self.audio_model = ClapAudioModel(audio_config) + self.audio_projection = ClapProjectionLayer(audio_config) + + # Initialize weights and apply final processing + self.post_init() + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`ClapTextModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, ClapModel + + >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused") + >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + + >>> inputs = tokenizer(["the sound of a cat", "the sound of a dog"], padding=True, return_tensors="pt") + >>> with torch.inference_mode(): + ... 
text_features = model.get_text_features(**inputs) + ```""" + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids + ) + text_features = self.text_projection(text_outputs.pooler_output) + text_features = F.normalize(text_features, dim=-1) + + return text_features + + @filter_out_non_signature_kwargs() + @auto_docstring + def get_audio_features( + self, + input_features: torch.Tensor, + is_longer: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r""" + is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*): + Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance + the features. + + Returns: + audio_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The audio embeddings obtained by + applying the projection layer to the pooled output of [`ClapAudioModel`]. + + Examples: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, ClapModel + + >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused") + >>> random_audio = torch.rand((16_000)) + + >>> inputs = feature_extractor(random_audio, return_tensors="pt") + >>> with torch.inference_mode(): + ... audio_features = model.get_audio_features(**inputs) + ```""" + audio_outputs: BaseModelOutputWithPooling = self.audio_model( + input_features=input_features, is_longer=is_longer + ) + audio_features = self.audio_projection(audio_outputs.pooler_output) + audio_features = F.normalize(audio_features, dim=-1) + + return audio_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + is_longer: Optional[torch.BoolTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ClapOutput]: + r""" + is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*): + Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance + the features. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, ClapModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused") + >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused") + + >>> input_text = ["Sound of a dog", "Sound of vacuum cleaner"] + + >>> inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True) + + >>> outputs = model(**inputs) + >>> logits_per_audio = outputs.logits_per_audio # this is the audio-text similarity score + >>> probs = logits_per_audio.softmax(dim=-1) # we can take the softmax to get the label probabilities + ```""" + # Use CLAP model's config for some fields (if specified) instead of those of audio & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + audio_outputs = self.audio_model( + input_features=input_features, + is_longer=is_longer, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output + audio_embeds = self.audio_projection(audio_embeds) + + text_embeds = text_outputs[1] if not return_dict else text_outputs.pooler_output + text_embeds = self.text_projection(text_embeds) + + # normalized features + audio_embeds = audio_embeds / audio_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale_text = self.logit_scale_t.exp() + logit_scale_audio = self.logit_scale_a.exp() + logits_per_text = torch.matmul(text_embeds, audio_embeds.t()) * logit_scale_text + logits_per_audio = torch.matmul(audio_embeds, text_embeds.t()) * logit_scale_audio + + loss = None + if return_loss: + caption_loss = contrastive_loss(logits_per_text) + audio_loss = contrastive_loss(logits_per_audio.t()) + loss = (caption_loss + audio_loss) / 2.0 + + return ClapOutput( + loss=loss, + logits_per_audio=logits_per_audio, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + audio_embeds=audio_embeds, + text_model_output=text_outputs, + audio_model_output=audio_outputs, + ) + + +@auto_docstring +class ClapTextModelWithProjection(ClapPreTrainedModel): + config: ClapTextConfig + + def __init__(self, config: ClapTextConfig): + super().__init__(config) + self.text_model = ClapTextModel(config) + self.text_projection = ClapProjectionLayer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.text_model.embeddings.word_embeddings = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ClapTextModelOutput]: + r""" + Examples: + + ```python + >>> from transformers import AutoTokenizer, ClapTextModelWithProjection + + >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused") + >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused") + + >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> text_embeds = outputs.text_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output + + text_embeds = self.text_projection(pooled_output) + + return ClapTextModelOutput( + text_embeds=text_embeds, + last_hidden_state=text_outputs.last_hidden_state, + hidden_states=text_outputs.hidden_states, + attentions=text_outputs.attentions, + ) + + +@auto_docstring +class ClapAudioModelWithProjection(ClapPreTrainedModel): + config: ClapAudioConfig + main_input_name = "input_features" + + def __init__(self, config: ClapAudioConfig): + super().__init__(config) + self.audio_model = ClapAudioModel(config) + self.audio_projection = ClapProjectionLayer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.audio_model.audio_encoder.patch_embed.proj + + @can_return_tuple + @auto_docstring + def forward( + self, + input_features: Optional[torch.FloatTensor] = None, + is_longer: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ClapAudioModelOutput]: + r""" + is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*): + Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance + the features. + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import ClapAudioModelWithProjection, ClapProcessor + + >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused") + >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused") + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> inputs = processor(audios=audio_sample, return_tensors="pt") + >>> outputs = model(**inputs) + >>> audio_embeds = outputs.audio_embeds + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + audio_outputs = self.audio_model( + input_features=input_features, + is_longer=is_longer, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output + + audio_embeds = self.audio_projection(pooled_output) + + return ClapAudioModelOutput( + audio_embeds=audio_embeds, + last_hidden_state=audio_outputs.last_hidden_state, + attentions=audio_outputs.attentions, + hidden_states=audio_outputs.hidden_states, + ) + + +__all__ = [ + "ClapModel", + "ClapPreTrainedModel", + "ClapTextModel", + "ClapTextModelWithProjection", + "ClapAudioModel", + "ClapAudioModelWithProjection", +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/processing_clap.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/processing_clap.py new file mode 100644 index 0000000000000000000000000000000000000000..6524a87158418206b8b96a7b57f6c1b7392e56cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/clap/processing_clap.py @@ -0,0 +1,75 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace 
Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio/Text processor class for CLAP +""" + +from typing import Optional, Union + +from ...audio_utils import AudioInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +from ...utils.deprecation import deprecate_kwarg + + +logger = logging.get_logger(__name__) + + +class ClapProcessor(ProcessorMixin): + r""" + Constructs a CLAP processor which wraps a CLAP feature extractor and a RoBerta tokenizer into a single processor. + + [`ClapProcessor`] offers all the functionalities of [`ClapFeatureExtractor`] and [`RobertaTokenizerFast`]. See the + [`~ClapProcessor.__call__`] and [`~ClapProcessor.decode`] for more information. + + Args: + feature_extractor ([`ClapFeatureExtractor`]): + The audio processor is a required input. + tokenizer ([`RobertaTokenizerFast`]): + The tokenizer is a required input. + """ + + feature_extractor_class = "ClapFeatureExtractor" + tokenizer_class = ("RobertaTokenizer", "RobertaTokenizerFast") + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + @deprecate_kwarg("audios", version="v4.59.0", new_name="audio") + def __call__( + self, + text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None, + audios: Optional[AudioInput] = None, + audio: Optional[AudioInput] = None, + **kwargs: Unpack[ProcessingKwargs], + ): + """ + Forwards the `audio` and `sampling_rate` arguments to [`~ClapFeatureExtractor.__call__`] and the `text` + argument to [`~RobertaTokenizerFast.__call__`]. Please refer to the docstring of the above two methods for more + information. + """ + # The `deprecate_kwarg` will not work if the inputs are passed as arguments, so we check + # again that the correct naming is used + if audios is not None and audio is None: + logger.warning( + "Using `audios` keyword argument is deprecated when calling ClapProcessor, instead use `audio`." + ) + audio = audios + + return super().__call__(text=text, audio=audio, **kwargs) + + +__all__ = ["ClapProcessor"] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b65c4bddb4b0cd3fa8dfd6a781a3c0f58e30e5a7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .tokenization_code_llama import *
+    from .tokenization_code_llama_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..94d1b4d659851a5f466c8123dd032ab213a90ed4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama.py
@@ -0,0 +1,454 @@
+# coding=utf-8
+# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for Code LLaMA."""
+
+import os
+from shutil import copyfile
+from typing import Any, Optional
+
+import sentencepiece as spm
+
+from ...convert_slow_tokenizer import import_protobuf
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging, requires_backends
+from ...utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+SPIECE_UNDERLINE = "▁"
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+@requires(backends=("sentencepiece",))
+class CodeLlamaTokenizer(PreTrainedTokenizer):
+    """
+    Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as
+    there is no padding token in the original model.
+
+    The default configuration matches that of
+    [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+    which supports prompt infilling.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+            Prefix token used for infilling.
+        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+            Middle token used for infilling.
+        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+            Suffix token used for infilling.
+        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+            End of text token used for infilling.
+        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+            The token used to split the input between the prefix and suffix.
+        suffix_first (`bool`, *optional*, defaults to `False`):
+            Whether the input prompt and suffix should be formatted with the suffix first.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether to add a beginning of sequence token at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether to add an end of sequence token at the end of sequences.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not to clean up the tokenization spaces.
+        additional_special_tokens (`list[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Llama should be used.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="",
+        bos_token="",
+        eos_token="",
+        prefix_token="▁
",
+        middle_token="▁",
+        suffix_token="▁",
+        eot_token="▁",
+        fill_token="",
+        suffix_first=False,
+        sp_model_kwargs: Optional[dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        additional_special_tokens=None,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        requires_backends(self, "protobuf")
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+
+        self.use_default_system_prompt = use_default_system_prompt
+        # mark tokens special to skip them
+        additional_special_tokens = additional_special_tokens or []
+        for token in [prefix_token, middle_token, suffix_token, eot_token]:
+            additional_special_tokens += [token] if token is not None else []
+
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self._prefix_token = prefix_token
+        self._middle_token = middle_token
+        self._suffix_token = suffix_token
+        self._eot_token = eot_token
+        self.fill_token = fill_token
+        self.suffix_first = suffix_first
+        self.sp_model = self.get_spm_processor()
+
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            prefix_token=prefix_token,
+            middle_token=middle_token,
+            suffix_token=suffix_token,
+            eot_token=eot_token,
+            fill_token=fill_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            suffix_first=suffix_first,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            additional_special_tokens=additional_special_tokens,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+
+    @property
+    def unk_token_length(self):
+        return len(self.sp_model.encode(str(self.unk_token)))
+
+    def get_spm_processor(self):
+        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        with open(self.vocab_file, "rb") as f:
+            sp_model = f.read()
+            model_pb2 = import_protobuf()
+            model = model_pb2.ModelProto.FromString(sp_model)
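+            # `add_dummy_prefix` is turned off below so the sentencepiece model never inserts a leading `▁`
+            # on its own; `tokenize`/`_tokenize` add and strip the prefix space manually instead.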
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
+            sp_model = model.SerializeToString()
+            tokenizer.LoadFromSerializedProto(sp_model)
+        return tokenizer
+
+    @property
+    def prefix_token(self):
+        return self._prefix_token
+
+    @property
+    def prefix_id(self):
+        if self._prefix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.prefix_token)
+
+    @property
+    def middle_token(self):
+        return self._middle_token
+
+    @property
+    def middle_id(self):
+        if self._middle_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.middle_token)
+
+    @property
+    def suffix_token(self):
+        return self._suffix_token
+
+    @property
+    def suffix_id(self):
+        if self._suffix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.suffix_token)
+
+    @property
+    def eot_token(self):
+        return self._eot_token
+
+    @property
+    def eot_id(self):
+        if self._eot_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eot_token)
+
+    @property
+    def vocab_size(self):
+        """Returns vocab size"""
+        return self.sp_model.get_piece_size()
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
+    def get_vocab(self):
+        """Returns vocab as a dict"""
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> list[str]:
+        # add a prefix space to `prefix`
+        if self.fill_token is not None and self.fill_token in prefix and suffix is None:
+            prefix, suffix = prefix.split(self.fill_token)
+
+        if len(prefix) > 0:
+            prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
+
+        if suffix is None or len(suffix) < 1:
+            tokens = super().tokenize(prefix, **kwargs)
+            if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
+                tokens = tokens[1:]
+            return tokens
+
+        prefix_tokens = self._tokenize(prefix)  # prefix has an extra `SPIECE_UNDERLINE`
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "The input either includes a `prefix` and a `suffix` used for the infilling task,"
+                f"  or can be split on the {self.fill_token} token, creating a suffix and prefix,"
+                " but the model does not support `infilling`."
+            )
+        suffix_tokens = self._tokenize(suffix)  # make sure CodeLlama sp model does not mess up
+
+        suffix_first = suffix_first if suffix_first is not None else self.suffix_first
+        if suffix_first:
+            # format as " <PRE> <SUF>{suf} <MID> {pre}"
+            return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
+        else:
+            # format as " <PRE> {pre} <SUF>{suf} <MID>"
+            return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Returns a tokenized string.
+
+        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
+        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
+        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
+        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
+        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
+        """
+        tokens = self.sp_model.encode(text, out_type=str)
+        if not text.startswith((SPIECE_UNDERLINE, " ")):
+            return tokens
+        # 1. Encode string + prefix ex: " Hey"
+        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
+        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
+        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.sp_model.piece_to_id(token)
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        return token
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        # since we manually add the prefix space, we have to remove it when decoding
+        if tokens[0].startswith(SPIECE_UNDERLINE):
+            tokens[0] = tokens[0][1:]
+
+        current_sub_tokens = []
+        out_string = ""
+        for _, token in enumerate(tokens):
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> tuple[str]:
+        """
+        Save the vocabulary and special tokens file to a directory.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+
+        Returns:
+            `Tuple(str)`: Paths to the files saved.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = bos_token_id + token_ids_0 + eos_token_id
+
+        if token_ids_1 is not None:
+            output = output + bos_token_id + token_ids_1 + eos_token_id
+
+        return output
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        bos_token_id = [1] if self.add_bos_token else []
+        eos_token_id = [1] if self.add_eos_token else []
+
+        if token_ids_1 is None:
+            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
+        return (
+            bos_token_id
+            + ([0] * len(token_ids_0))
+            + eos_token_id
+            + bos_token_id
+            + ([0] * len(token_ids_1))
+            + eos_token_id
+        )
+
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
+        sequence pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        if token_ids_1 is None, only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
+
+        if token_ids_1 is not None:
+            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
+
+        return output
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+
+__all__ = ["CodeLlamaTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3978587e7f02512a5344f9ad0a33bf86b839757
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -0,0 +1,374 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Optional
+
+from tokenizers import normalizers, processors
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+    from .tokenization_code_llama import CodeLlamaTokenizer
+else:
+    CodeLlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+SPIECE_UNDERLINE = "▁"
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    This uses notably ByteFallback and no normalization.
+
+    ```python
+    >>> from transformers import CodeLlamaTokenizerFast
+
+    >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+    >>> tokenizer.encode("Hello this is a test")
+    [1, 15043, 445, 338, 263, 1243]
+    ```
+
+    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+    values of the first token and final token of an encoded sequence will not be correct). For more details, check out
+    the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods. The default configuration matches that of
+    [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
+    which supports prompt infilling.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        tokenizer_file (`str`, *optional*):
+            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+            contains everything needed to load the tokenizer.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
+            spaces.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
+            Prefix token used for infilling.
+        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
+            Middle token used for infilling.
+        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
+            Suffix token used for infilling.
+        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
+            End of text token used for infilling.
+        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
+            The token used to split the input between the prefix and suffix.
+        additional_special_tokens (`list[str]`, *optional*):
+            Additional special tokens used by the tokenizer.
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether to add a beginning of sequence token at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether to add an end of sequence token at the end of sequences.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Llama should be used.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = CodeLlamaTokenizer
+    padding_side = "left"
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        clean_up_tokenization_spaces=False,
+        unk_token="<unk>",
+        bos_token="<s>",
+        eos_token="</s>",
+        prefix_token="▁<PRE>",
+        middle_token="▁<MID>",
+        suffix_token="▁<SUF>",
+        eot_token="▁<EOT>",
+        fill_token="<FILL_ME>",
+        additional_special_tokens=None,
+        add_bos_token=True,
+        add_eos_token=False,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        # mark tokens special to skip them
+        additional_special_tokens = additional_special_tokens or []
+        for token in [prefix_token, middle_token, suffix_token, eot_token]:
+            additional_special_tokens += [token] if token is not None else []
+        self.use_default_system_prompt = use_default_system_prompt
+
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            additional_special_tokens=additional_special_tokens,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            prefix_token=prefix_token,
+            middle_token=middle_token,
+            suffix_token=suffix_token,
+            eot_token=eot_token,
+            fill_token=fill_token,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
+
+        self.vocab_file = vocab_file
+
+        self._prefix_token = prefix_token
+        self._middle_token = middle_token
+        self._suffix_token = suffix_token
+        self._eot_token = eot_token
+        self.fill_token = fill_token
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
+    def update_post_processor(self):
+        """
+        Updates the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+        if bos is None and self.add_bos_token:
+            raise ValueError("add_bos_token = True but bos_token = None")
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+        if eos is None and self.add_eos_token:
+            raise ValueError("add_eos_token = True but eos_token = None")
+
+        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
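+        # With the default tokens this builds templates like "<s>:0 $A:0" (single) and "<s>:0 $A:0 <s>:1 $B:1"
+        # (pair); illustrative only, the exact strings depend on the configured bos/eos tokens and flags.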
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def prefix_token(self):
+        return self._prefix_token
+
+    @property
+    def prefix_id(self):
+        if self._prefix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.prefix_token)
+
+    @property
+    def middle_token(self):
+        return self._middle_token
+
+    @property
+    def middle_id(self):
+        if self._middle_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.middle_token)
+
+    @property
+    def suffix_token(self):
+        return self._suffix_token
+
+    @property
+    def suffix_id(self):
+        if self._suffix_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.suffix_token)
+
+    @property
+    def eot_id(self):
+        if self._eot_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.eot_token)
+
+    @property
+    def eot_token(self):
+        return self._eot_token
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
+    def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+        """
+        Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
+        following: if suffix_first
+            " <PRE> <SUF>{suf} <MID> {pre}"
+        else:
+            " <PRE> {pre} <SUF>{suf} <MID>"
+
+        If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
+        is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
+        """
+        if reset:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [
+                    normalizers.Prepend(prepend="▁"),
+                    normalizers.Replace(pattern=" ", content="▁"),
+                ]
+            )
+            self.update_post_processor()
+            return
+
+        self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
+        pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+        special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+        if suffix_first:
+            # format as " <PRE> <SUF>{suf} <MID> {pre}"
+            pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+        else:
+            # format as " <PRE> {pre} <SUF>{suf} <MID>"
+            pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+
+        if self.add_eos_token and add_special_tokens:
+            pair += [self.eos_token]
+            special_tokens += [(self.eos_token, self.eos_token_id)]
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single="$A", pair=pair, special_tokens=special_tokens
+        )
+
+    def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+        # hack to make sure the input is pre-processed, but outside of the Rust tokenizer
+        text_pair = kwargs.pop("suffix", text_pair)
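+        # e.g. (illustrative) "def add(a, b):<FILL_ME>    return result" splits on the fill token into the
+        # prefix "def add(a, b):" and the suffix "    return result".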
+        if self.fill_token is not None and self.fill_token in text and text_pair is None:
+            text, text_pair = text.split(self.fill_token)
+
+        if text_pair is None or len(text_pair) < 1:
+            return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "When the input includes a `prefix` and a `suffix` used for the infilling task,"
+                " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
+                f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
+            )
+
+        self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+        tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+        self.set_infilling_processor(True)
+        return tokens
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. The special tokens depend on calling set_lang.
+
+        An NLLB sequence has the following format, where `X` represents the sequence:
+
+        - `input_ids` (for encoder) `X [eos, src_lang_code]`
+        - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`list[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+        return [self.bos_token_id] + token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+
+__all__ = ["CodeLlamaTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2d9af11150f556b2cedfb78271f174256e64b0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_codegen import *
+    from .modeling_codegen import *
+    from .tokenization_codegen import *
+    from .tokenization_codegen_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/configuration_codegen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/configuration_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a9ab842710cdd67911305ce324fe4c68dec173b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/configuration_codegen.py
@@ -0,0 +1,231 @@
+# coding=utf-8
+# Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CodeGen model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class CodeGenConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`CodeGenModel`]. It is used to instantiate a
+    CodeGen model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the CodeGen
+    [Salesforce/codegen-2B-mono](https://huggingface.co/Salesforce/codegen-2B-mono) architecture. Configuration objects
+    inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from
+    [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50400):
+            Vocabulary size of the CodeGen model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`CodeGenModel`].
+        n_positions (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        n_ctx (`int`, *optional*, defaults to 2048):
+            This attribute is used in `CodeGenModel.__init__` without any real effect.
+        n_embd (`int`, *optional*, defaults to 4096):
+            Dimensionality of the embeddings and hidden states.
+        n_layer (`int`, *optional*, defaults to 28):
+            Number of hidden layers in the Transformer encoder.
+        n_head (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        rotary_dim (`int`, *optional*, defaults to 64):
+            Number of dimensions in the embedding that Rotary Position Embedding is applied to.
+        n_inner (`int`, *optional*):
+            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
+            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        embd_pdrop (`int`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attn_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        bos_token_id (`int`, *optional*, defaults to 50256):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 50256):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+            model has an output word embedding layer.
+
+    Example:
+
+    ```python
+    >>> from transformers import CodeGenConfig, CodeGenModel
+
+    >>> # Initializing a CodeGen 6B configuration
+    >>> configuration = CodeGenConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = CodeGenModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "codegen"
+    attribute_map = {
+        "max_position_embeddings": "n_positions",
+        "hidden_size": "n_embd",
+        "num_attention_heads": "n_head",
+        "num_hidden_layers": "n_layer",
+    }
+
+    def __init__(
+        self,
+        vocab_size=50400,
+        n_positions=2048,
+        n_ctx=2048,
+        n_embd=4096,
+        n_layer=28,
+        n_head=16,
+        rotary_dim=64,
+        n_inner=None,
+        activation_function="gelu_new",
+        resid_pdrop=0.0,
+        embd_pdrop=0.0,
+        attn_pdrop=0.0,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        use_cache=True,
+        bos_token_id=50256,
+        eos_token_id=50256,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.n_ctx = n_ctx
+        self.n_positions = n_positions
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_inner = n_inner
+        self.rotary_dim = rotary_dim
+        self.activation_function = activation_function
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attn_pdrop = attn_pdrop
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+        super().__init__(
+            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
+        )
+
+
+# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
+class CodeGenOnnxConfig(OnnxConfigWithPast):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "default",
+        patching_specs: Optional[list[PatchingSpec]] = None,
+        use_past: bool = False,
+    ):
+        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+        if not getattr(self._config, "pad_token_id", None):
+            # TODO: how to do that better?
+            self._config.pad_token_id = 0
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+        else:
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+        return common_inputs
+
+    @property
+    def num_layers(self) -> int:
+        return self._config.n_layer
+
+    @property
+    def num_attention_heads(self) -> int:
+        return self._config.n_head
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )
+
+        # We need to order the inputs in the way they appear in the forward()
+        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+        # Need to add the past_keys
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                batch, seqlen = common_inputs["input_ids"].shape
+                # Not using the same length for past_key_values
+                past_key_values_length = seqlen + 2
+                past_shape = (
+                    batch,
+                    self.num_attention_heads,
+                    past_key_values_length,
+                    self._config.hidden_size // self.num_attention_heads,
+                )
+                ordered_inputs["past_key_values"] = [
+                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                ]
+
+        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+        if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
+            ordered_inputs["attention_mask"] = torch.cat(
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+            )
+
+        return ordered_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
+
+
+__all__ = ["CodeGenConfig", "CodeGenOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/modeling_codegen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/modeling_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..887b400b479929c947ed2dfeadc436294033ad8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/modeling_codegen.py
@@ -0,0 +1,668 @@
+# coding=utf-8
+# Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch CodeGen model."""
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+    auto_docstring,
+    is_torch_flex_attn_available,
+    logging,
+)
+from .configuration_codegen import CodeGenConfig
+
+
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
+def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
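+    # Builds a (num_pos, dim) sinusoid table: sin(pos * inv_freq) in the first dim // 2 columns and
+    # cos(pos * inv_freq) in the remaining columns, used as the rotary position embedding cache.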
+    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
+    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+
+
+# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
+def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
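+    # Rotates adjacent feature pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...), the pairwise
+    # rotation needed by the rotary position embedding.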
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+
+
+# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
+def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
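+    # `sin`/`cos` are duplicated along the last dim to match the interleaved pair layout and broadcast over a new
+    # head axis before the rotation (tensor * cos + rotate_every_two(tensor) * sin) is applied.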
+    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    return (tensor * cos) + (rotate_every_two(tensor) * sin)
+
+
+class CodeGenAttention(nn.Module):
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+
+        max_positions = config.max_position_embeddings
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.embed_dim = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_attention_heads
+        if self.head_dim * self.num_attention_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+                f" `num_attention_heads`: {self.num_attention_heads})."
+            )
+        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
+
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.rotary_dim = config.rotary_dim
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+
+    def _split_heads(self, x, n_head, dim_head, mp_num):
+        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
+        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
+        return reshaped
+
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into n_ctx
+        """
+        if len(tensor.shape) == 5:
+            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
+        elif len(tensor.shape) == 4:
+            tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        else:
+            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def _attn(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        head_mask=None,
+    ):
+        # Keep the attention weights computation in fp32 to avoid overflow issues
+        query = query.to(torch.float32)
+        key = key.to(torch.float32)
+
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+            attn_weights += causal_mask
+
+        attn_weights = attn_weights / self.scale_attn
+        attn_weights = nn.Softmax(dim=-1)(attn_weights)
+        attn_weights = attn_weights.to(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+    def forward(
+        self,
+        hidden_states: Optional[torch.FloatTensor],
+        layer_past: Optional[Cache] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[
+        tuple[torch.Tensor, tuple[torch.Tensor]],
+        Optional[tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]],
+    ]:
+        qkv = self.qkv_proj(hidden_states)
+        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
+        mp_num = 4
+        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
+
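+        # The fused qkv projection stores its output interleaved across `mp_num` model-parallel shards (a remnant
+        # of the original TPU layout): the reshape above groups per shard and the split below separates
+        # query, value, key (in that order).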
+        local_dim = self.head_dim * self.num_attention_heads // mp_num
+        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+        value = value.permute(0, 2, 1, 3)
+
+        embed_positions = self.embed_positions
+        if embed_positions.device != position_ids.device:
+            embed_positions = embed_positions.to(position_ids.device)
+            self.embed_positions = embed_positions
+
+        sincos = embed_positions[position_ids]
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+
+        if self.rotary_dim is not None:
+            k_rot = key[:, :, :, : self.rotary_dim]
+            k_pass = key[:, :, :, self.rotary_dim :]
+
+            q_rot = query[:, :, :, : self.rotary_dim]
+            q_pass = query[:, :, :, self.rotary_dim :]
+
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+
+            key = torch.cat([k_rot, k_pass], dim=-1)
+            query = torch.cat([q_rot, q_pass], dim=-1)
+        else:
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)
+
+        key = key.permute(0, 2, 1, 3)
+        query = query.permute(0, 2, 1, 3)
+
+        # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
+        # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
+        if layer_past is not None:
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "partial_rotation_size": self.rotary_dim,
+                "cache_position": cache_position,
+            }
+            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
+
+        # compute self-attention: V x Softmax(QK^T)
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+        return attn_output, attn_weights
+
+
+# Copied from transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen
+class CodeGenMLP(nn.Module):
+    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * embed_dim
+        super().__init__()
+        embed_dim = config.n_embd
+
+        self.fc_in = nn.Linear(embed_dim, intermediate_size)
+        self.fc_out = nn.Linear(intermediate_size, embed_dim)
+
+        self.act = ACT2FN[config.activation_function]
+        self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
+        hidden_states = self.fc_in(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.fc_out(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
+class CodeGenBlock(GradientCheckpointingLayer):
+    # Ignore copy
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
+        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.attn = CodeGenAttention(config, layer_idx)
+        self.mlp = CodeGenMLP(inner_dim, config)
+
+    def forward(
+        self,
+        hidden_states: Optional[torch.FloatTensor],
+        layer_past: Optional[Cache] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_outputs, attn_weights = self.attn(
+            hidden_states=hidden_states,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual
+
+        return hidden_states, attn_weights
+
+
+@auto_docstring
+class CodeGenPreTrainedModel(PreTrainedModel):
+    config: CodeGenConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["CodeGenBlock"]
+    _skip_keys_device_placement = "past_key_values"
+
+    _can_compile_fullgraph = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear,)):
+            # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class CodeGenModel(CodeGenPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.n_embd
+        self.vocab_size = config.vocab_size
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+        self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
+
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.wte
+
+    def set_input_embeddings(self, new_embeddings):
+        self.wte = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwargs, for now
+    ) -> Union[tuple, BaseModelOutputWithPast]:
+        r"""
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        seq_length = inputs_embeds.shape[1]
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x num_attention_heads x N x N
+        # head_mask has shape n_layer x batch x num_attention_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+        hidden_states = inputs_embeds
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, seq_length)
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+        output_shape = (-1, seq_length, hidden_states.size(-1))
+
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, block in enumerate(self.h):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            outputs = block(
+                hidden_states,
+                layer_past=past_key_values,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+
+            hidden_states = outputs[0]
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[1],)
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: Union[torch.Tensor, "BlockMask"],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool = False,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and (attention_mask == 0.0).any():
+                return attention_mask
+            return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            return attention_mask
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype = input_tensor.dtype
+        sequence_length = input_tensor.shape[1]
+        if using_compilable_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+@auto_docstring(
+    custom_intro="""
+    The CodeGen Model transformer with a language modeling head on top.
+    """
+)
+class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = CodeGenModel(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple, CausalLMOutputWithPast]:
+        r"""
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        hidden_states = transformer_outputs[0]
+
+        # make sure sampling in fp16 works correctly and
+        # compute loss in fp32 to match with mesh-tf version
+        # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
+        lm_logits = self.lm_head(hidden_states).to(torch.float32)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(lm_logits.device)
+            # Flatten the tokens
+            loss = self.loss_function(
+                lm_logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+            loss = loss.to(hidden_states.dtype)
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+__all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen.py
new file mode 100644
index 0000000000000000000000000000000000000000..152b1a84fc37d5ed6613072284133565ebc86cbf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for CodeGen"""
+
+import json
+import os
+from functools import lru_cache
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+import regex as re
+
+from ...utils import is_tf_available, is_torch_available, logging, to_py_obj
+
+
+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch
+    if is_tf_available():
+        import tensorflow as tf
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "merges_file": "merges.txt",
+}
+
+
+@lru_cache
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
+    characters that the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
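+
+# Illustrative sketch of the mapping above (not part of the upstream file): printable bytes map to themselves,
+# while bytes the BPE code cannot handle (e.g. the space byte 32) are shifted to unused unicode code points.
+#
+#   byte_encoder = bytes_to_unicode()
+#   assert byte_encoder[ord("A")] == "A"
+#   assert byte_encoder[ord(" ")] == chr(256 + 32)  # "Ġ", the visible stand-in for a leading space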
+
+
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
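+
+# Illustrative sketch (not part of the upstream file): get_pairs(("c", "o", "d", "e"))
+# returns {("c", "o"), ("o", "d"), ("d", "e")}.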
+
+
+class CodeGenTokenizer(PreTrainedTokenizer):
+    """
+    Construct a CodeGen tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece), so a word will
+    be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+    ```python
+    >>> from transformers import CodeGenTokenizer
+
+    >>> tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
+    >>> tokenizer("Hello world")["input_ids"]
+    [15496, 995]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [18435, 995]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        merges_file (`str`):
+            Path to the merges file.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The beginning of sequence token.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*):
+            The token used for padding, for example when batching sequences of different lengths.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just like any
+            other word. (The CodeGen tokenizer detects the beginning of words by the preceding space.)
+        add_bos_token (`bool`, *optional*, defaults to `False`):
+            Whether to add a beginning of sequence token at the start of sequences.
+        return_token_type_ids (`bool`, *optional*, defaults to `False`):
+            Whether to return token type IDs.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        errors="replace",
+        unk_token="<|endoftext|>",
+        bos_token="<|endoftext|>",
+        eos_token="<|endoftext|>",
+        pad_token=None,
+        add_prefix_space=False,
+        add_bos_token=False,
+        return_token_type_ids=False,
+        **kwargs,
+    ):
+        bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
+        self.add_bos_token = add_bos_token
+        self.return_token_type_ids = return_token_type_ids
+        if self.return_token_type_ids:
+            self.model_input_names.append("token_type_ids")
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            bpe_merges = merges_handle.read().split("\n")[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.add_prefix_space = add_prefix_space
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        super().__init__(
+            errors=errors,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            add_prefix_space=add_prefix_space,
+            add_bos_token=add_bos_token,
+            return_token_type_ids=return_token_type_ids,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return len(self.encoder)
+
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
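+
+    # Illustrative sketch of `bpe` (not part of the upstream file): with hypothetical merges
+    # [("h", "e"), ("he", "l")], bpe("hell") first merges "h"+"e" -> "he", then "he"+"l" -> "hel",
+    # and returns the space-joined pieces "hel l".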
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        if self.add_bos_token:
+            bos_token_ids = [self.bos_token_id]
+        else:
+            bos_token_ids = []
+
+        output = bos_token_ids + token_ids_0
+
+        if token_ids_1 is None:
+            return output
+
+        return output + bos_token_ids + token_ids_1
+
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+        if is_split_into_words or add_prefix_space:
+            text = " " + text
+        return (text, kwargs)
+
+    def decode(
+        self,
+        token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        truncate_before_pattern: Optional[list[str]] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Converts a sequence of ids into a string, using the tokenizer and vocabulary, with options to remove special
+        tokens and clean up tokenization spaces.
+
+        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
+            truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
+                A list of regular expression strings that will be used to truncate the returned string. This can be
+                used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
+                of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
+
+        Returns:
+            `str`: The decoded sentence.
+        """
+
+        token_ids = to_py_obj(token_ids)
+
+        decoded_text = super()._decode(
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+        if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
+            decoded_text = self.truncate(decoded_text, truncate_before_pattern)
+
+        return decoded_text
+
+    def truncate(self, completion, truncate_before_pattern):
+        def find_re(string, pattern, start_pos):
+            m = pattern.search(string, start_pos)
+            return m.start() if m else -1
+
+        terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]
+
+        prints = list(re.finditer("^print", completion, re.MULTILINE))
+
+        if len(prints) > 1:
+            completion = completion[: prints[1].start()]
+
+        defs = list(re.finditer("^def", completion, re.MULTILINE))
+
+        if len(defs) > 1:
+            completion = completion[: defs[1].start()]
+
+        start_pos = 0
+
+        terminals_pos = [
+            pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
+        ]
+
+        if len(terminals_pos) > 0:
+            return completion[: min(terminals_pos)]
+        else:
+            return completion
+
+
+__all__ = ["CodeGenTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bac0db7de4e7c548ddea0eb2b3c498919c6196a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/codegen/tokenization_codegen_fast.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+
+import re
+from typing import TYPE_CHECKING, Optional, Union
+
+import numpy as np
+
+from ...utils import is_tf_available, is_torch_available, logging
+
+
+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch
+    if is_tf_available():
+        import tensorflow as tf
+
+
+from ...tokenization_utils_base import BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from .tokenization_codegen import CodeGenTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class CodeGenTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" CodeGen tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
+    Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece), so a word will
+    be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:
+
+    ```python
+    >>> from transformers import CodeGenTokenizerFast
+
+    >>> tokenizer = CodeGenTokenizerFast.from_pretrained("Salesforce/codegen-350M-mono")
+    >>> tokenizer("Hello world")["input_ids"]
+    [15496, 995]
+
+    >>> tokenizer(" Hello world")["input_ids"]
+    [18435, 995]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
+    the model was not pretrained this way, it might yield a decrease in performance.
+
+    
+
+    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
+
+    
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            Path to the vocabulary file.
+        merges_file (`str`, *optional*):
+            Path to the merges file.
+        tokenizer_file (`str`, *optional*):
+            Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+            contains everything needed to load the tokenizer.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The beginning of sequence token.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just like any
+            other word. (The CodeGen tokenizer detects the beginning of words by the preceding space.)
+        return_token_type_ids (`bool`, *optional*, defaults to `False`):
+            Whether to return token type IDs.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = CodeGenTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        unk_token="<|endoftext|>",
+        bos_token="<|endoftext|>",
+        eos_token="<|endoftext|>",
+        add_prefix_space=False,
+        return_token_type_ids=False,
+        **kwargs,
+    ):
+        self.return_token_type_ids = return_token_type_ids
+        if self.return_token_type_ids:
+            self.model_input_names.append("token_type_ids")
+
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=tokenizer_file,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            add_prefix_space=add_prefix_space,
+            return_token_type_ids=return_token_type_ids,
+            **kwargs,
+        )
+
+        if kwargs.pop("add_bos_token", False):
+            model_id = kwargs.pop("name_or_path", "")
+            raise ValueError(
+                "Currently GPT2's fast tokenizer does NOT support adding a BOS token. "
+                "Instead you should use GPT2's slow tokenizer class `CodeGenTokenizer` as follows: \n"
+                f"`CodeGenTokenizer.from_pretrained('{model_id}')`\nor\n"
+                f"`AutoTokenizer.from_pretrained('{model_id}', use_fast=False)`\n"
+                "This issue will be fixed soon, see: https://github.com/huggingface/tokenizers/pull/1005."
+                " so that the fast tokenizer works correctly."
+            )
+
+    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+        assert self.add_prefix_space or not is_split_into_words, (
+            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+
+        return super()._batch_encode_plus(*args, **kwargs)
+
+    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+
+        assert self.add_prefix_space or not is_split_into_words, (
+            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+
+        return super()._encode_plus(*args, **kwargs)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+        return tuple(files)
+
+    def decode(
+        self,
+        token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        truncate_before_pattern: Optional[list[str]] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Converts a sequence of ids into a string, using the tokenizer and vocabulary, with options to remove special
+        tokens and clean up tokenization spaces.
+
+        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
+            truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
+                A list of regular expression strings that will be used to truncate the returned string. This can be
+                used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
+                of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
+
+        Returns:
+            `str`: The decoded sentence.
+        """
+
+        decoded_text = super().decode(
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+        if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
+            decoded_text = self.truncate(decoded_text, truncate_before_pattern)
+
+        return decoded_text
+
+    def truncate(self, completion, truncate_before_pattern):
+        def find_re(string, pattern, start_pos):
+            m = pattern.search(string, start_pos)
+            return m.start() if m else -1
+
+        terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]
+
+        prints = list(re.finditer("^print", completion, re.MULTILINE))
+
+        if len(prints) > 1:
+            completion = completion[: prints[1].start()]
+
+        defs = list(re.finditer("^def", completion, re.MULTILINE))
+
+        if len(defs) > 1:
+            completion = completion[: defs[1].start()]
+
+        start_pos = 0
+
+        terminals_pos = [
+            pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
+        ]
+
+        if len(terminals_pos) > 0:
+            return completion[: min(terminals_pos)]
+        else:
+            return completion
+
+
+__all__ = ["CodeGenTokenizerFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1447f65935601f0fffd8a88dac25bc5916b35f83
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Cohere and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_cohere2 import *
+    from .modeling_cohere2 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/configuration_cohere2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/configuration_cohere2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c92f63cad312651fd750fb12156166f95ce3d8d4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/configuration_cohere2.py
@@ -0,0 +1,232 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_cohere2.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+
+
+class Cohere2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`CohereModel`]. It is used to instantiate a Cohere
+    model according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 256000):
+            Vocabulary size of the Cohere model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`CohereModel`].
+        hidden_size (`int`, *optional*, defaults to 8192):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 22528):
+            Dimension of the MLP representations.
+        logit_scale (`float`, *optional*, defaults to 0.0625):
+            The scaling factor for the output logits.
+        num_hidden_layers (`int`, *optional*, defaults to 40):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 64):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 8192):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 5):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 255001):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie the input and output word embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Size of the sliding window attention context.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+
+    ```python
+    >>> from transformers import Cohere2Model, Cohere2Config
+
+    >>> # Initializing a Cohere2 model configuration
+    >>> configuration = Cohere2Config()
+
+    >>> # Initializing a model from the Cohere2 configuration
+    >>> model = Cohere2Model(configuration) # doctest: +SKIP
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config # doctest: +SKIP
+    ```
+    """
+
+    model_type = "cohere2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=256000,
+        hidden_size=8192,
+        intermediate_size=22528,
+        logit_scale=0.0625,
+        num_hidden_layers=40,
+        num_attention_heads=64,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=8192,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=5,
+        eos_token_id=255001,
+        tie_word_embeddings=True,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        sliding_window=4096,
+        layer_types=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.logit_scale = logit_scale
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types
+        # Need to specify head_dim in the config so it can be used in the attention forward functions
+        self.head_dim = hidden_size // num_attention_heads
+
+        # Validate the correctness of rotary position embeddings parameters
+        rope_config_validation(self)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 4)
+
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+__all__ = ["Cohere2Config"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modeling_cohere2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modeling_cohere2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab804aab67ec2fa1006cb70ac83766e8f2e71aa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modeling_cohere2.py
@@ -0,0 +1,513 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_cohere2.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_cohere2 import Cohere2Config
+
+
+class Cohere2RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Cohere2Config, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Cohere2LayerNorm(nn.Module):
+    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
+        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        mean = hidden_states.mean(-1, keepdim=True)
+        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = self.weight.to(torch.float32) * hidden_states
+        return hidden_states.to(input_dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
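+
+# Illustrative shape check (hypothetical helper, not part of the upstream module): with 2
+# key/value heads and n_rep=4, repeat_kv expands the KV tensors to 8 attention heads.
+def _example_repeat_kv_shapes() -> None:
+    kv = torch.zeros(1, 2, 16, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
+    expanded = repeat_kv(kv, n_rep=4)
+    assert expanded.shape == (1, 8, 16, 64)  # (batch, num_attention_heads, seq_len, head_dim)
+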
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
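+
+# Illustrative shape sketch (hypothetical helper, not part of the upstream module): the KV
+# heads are expanded with repeat_kv, and the output comes back transposed to
+# (batch, seq_len, num_heads, head_dim) so the caller can flatten it to the hidden size.
+def _example_eager_attention_shapes() -> None:
+    from types import SimpleNamespace
+
+    module = SimpleNamespace(num_key_value_groups=4, training=False)
+    query = torch.randn(1, 8, 16, 64)
+    key = torch.randn(1, 2, 16, 64)
+    value = torch.randn(1, 2, 16, 64)
+    out, weights = eager_attention_forward(module, query, key, value, attention_mask=None, scaling=64**-0.5)
+    assert out.shape == (1, 16, 8, 64) and weights.shape == (1, 8, 16, 16)
+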
+
+def rotate_half(x):
+    # Split and rotate. Note that this function is different from e.g. Llama.
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+    return rot_x
+
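+
+# Illustrative example (hypothetical helper, not part of the upstream module): rotate_half
+# rotates adjacent even/odd channel pairs, matching the interleaved cos/sin layout produced
+# by Cohere2RotaryEmbedding (repeat_interleave) rather than Llama's split-in-half layout.
+def _example_rotate_half() -> None:
+    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    # pairs (1, 2) and (3, 4) become (-2, 1) and (-4, 3)
+    assert torch.equal(rotate_half(x), torch.tensor([-2.0, 1.0, -4.0, 3.0]))
+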
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
+
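+
+# Illustrative shape sketch (hypothetical helper, not part of the upstream module): cos/sin
+# from Cohere2RotaryEmbedding have shape (batch, seq_len, head_dim), while q/k are
+# (batch, num_heads, seq_len, head_dim), so the default unsqueeze_dim=1 broadcasts correctly.
+def _example_apply_rotary_pos_emb_shapes() -> None:
+    q = torch.randn(1, 8, 16, 64)
+    k = torch.randn(1, 2, 16, 64)
+    cos = torch.randn(1, 16, 64)  # stand-in for the rotary cosines
+    sin = torch.randn(1, 16, 64)  # stand-in for the rotary sines
+    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
+    assert q_rot.shape == q.shape and k_rot.shape == k.shape
+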
+
+class Cohere2Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+        )
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        if self.sliding_window is not None:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Cohere2MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
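+
+# Illustrative sketch (hypothetical helper, not part of the upstream module): the MLP is a
+# gated block, down_proj(act_fn(gate_proj(x)) * up_proj(x)); with the default "silu"
+# activation this is the SwiGLU pattern, expanding to `intermediate_size` and projecting back.
+def _example_mlp_shapes() -> None:
+    config = Cohere2Config(hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4)
+    mlp = Cohere2MLP(config)
+    x = torch.randn(2, 5, config.hidden_size)
+    assert mlp(x).shape == (2, 5, 32)
+    assert mlp.gate_proj.out_features == 64 and mlp.down_proj.out_features == 32
+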
+
+class Cohere2DecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: Cohere2Config, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
+        self.mlp = Cohere2MLP(config)
+        self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+        self.attention_type = config.layer_types[layer_idx]
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+        """
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states_attention, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states_mlp = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states_attention + hidden_states_mlp
+        return hidden_states
+
+
+@auto_docstring
+class Cohere2PreTrainedModel(PreTrainedModel):
+    config: Cohere2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Cohere2DecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    _can_compile_fullgraph = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": Cohere2DecoderLayer,
+        "attentions": Cohere2Attention,
+    }
+
+
+@auto_docstring
+class Cohere2Model(Cohere2PreTrainedModel):
+    def __init__(self, config: Cohere2Config):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+        self.rotary_emb = Cohere2RotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None and not self.training:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+            }
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Cohere2Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.logit_scale = config.logit_scale
+        self.tie_word_embeddings = config.tie_word_embeddings
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Cohere2ForCausalLM
+
+        >>> model = Cohere2ForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
+        >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = logits * self.logit_scale  # main diff from Llama
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modular_cohere2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modular_cohere2.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ed748e0361841b9e9df272a3feb6884fcd6fd9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/cohere2/modular_cohere2.py
@@ -0,0 +1,444 @@
+# coding=utf-8
+# Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+
+from ...cache_utils import Cache, DynamicCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging
+from ...utils.deprecation import deprecate_kwarg
+from ..cohere.modeling_cohere import (
+    CohereAttention,
+    CohereDecoderLayer,
+    CohereForCausalLM,
+    CohereLayerNorm,
+    CoherePreTrainedModel,
+    CohereRotaryEmbedding,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
+from ..gemma2.modeling_gemma2 import Gemma2Model
+
+
+logger = logging.get_logger(__name__)
+
+
+class Cohere2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Cohere2Model`]. It is used to instantiate a Cohere2
+    model according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the [CohereForAI/c4ai-command-r-v01](https://huggingface.co/CohereForAI/c4ai-command-r-v01) model.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 256000):
+            Vocabulary size of the Cohere2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Cohere2Model`]
+        hidden_size (`int`, *optional*, defaults to 8192):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 22528):
+            Dimension of the MLP representations.
+        logit_scale (`float`, *optional*, defaults to 0.0625):
+            The scaling factor for the output logits.
+        num_hidden_layers (`int`, *optional*, defaults to 40):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 64):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 8192):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 5):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 255001):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether to tie the input and output word embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Size of the sliding window attention context.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
+
+    ```python
+    >>> from transformers import Cohere2Model, Cohere2Config
+
+    >>> # Initializing a Cohere2 model configuration
+    >>> configuration = Cohere2Config()
+
+    >>> # Initializing a model from the Cohere2 configuration
+    >>> model = Cohere2Model(configuration) # doctest: +SKIP
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config # doctest: +SKIP
+    ```
+    """
+
+    model_type = "cohere2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=256000,
+        hidden_size=8192,
+        intermediate_size=22528,
+        logit_scale=0.0625,
+        num_hidden_layers=40,
+        num_attention_heads=64,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=8192,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=5,
+        eos_token_id=255001,
+        tie_word_embeddings=True,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        sliding_window=4096,
+        layer_types=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.logit_scale = logit_scale
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types
+        # Need to specify head_dim in the config so it can be used in the attention forward functions
+        self.head_dim = hidden_size // num_attention_heads
+
+        # Validate the correctness of rotary position embeddings parameters
+        rope_config_validation(self)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+        # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
+        self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 4)
+
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+
+class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
+    pass
+
+
+class Cohere2LayerNorm(CohereLayerNorm):
+    pass
+
+
+class Cohere2Attention(CohereAttention):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
+        nn.Module.__init__(self)
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+        )
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        if self.sliding_window is not None:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Cohere2DecoderLayer(CohereDecoderLayer):
+    def __init__(self, config: Cohere2Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        self.attention_type = config.layer_types[layer_idx]
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states_attention, _ = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states_mlp = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states_attention + hidden_states_mlp
+        return hidden_states
+
+
+class Cohere2PreTrainedModel(CoherePreTrainedModel):
+    config: Cohere2Config
+
+
+class Cohere2Model(Gemma2Model):
+    def __init__(self, config: Cohere2Config):
+        super().__init__(config)
+        self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+        self.rotary_emb = Cohere2RotaryEmbedding(config=config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None and not self.training:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+            }
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers:
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+class Cohere2ForCausalLM(CohereForCausalLM):
+    pass
+
+
+__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2d826745f5b2e011997179ac0dd3d3cfc14389d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_convnext import *
+    from .feature_extraction_convnext import *
+    from .image_processing_convnext import *
+    from .image_processing_convnext_fast import *
+    from .modeling_convnext import *
+    from .modeling_tf_convnext import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/configuration_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/configuration_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..f54cba58cf296e8c8e3bae70a9f2e2ab21e3c660
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/configuration_convnext.py
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
+    ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the ConvNeXT
+    [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        patch_size (`int`, *optional*, defaults to 4):
+            Patch size to use in the patch embedding layer.
+        num_stages (`int`, *optional*, defaults to 4):
+            The number of stages in the model.
+        hidden_sizes (`list[int]`, *optional*, defaults to [96, 192, 384, 768]):
+            Dimensionality (hidden size) at each stage.
+        depths (`list[int]`, *optional*, defaults to [3, 3, 9, 3]):
+            Depth (number of blocks) for each stage.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
+            The initial value for the layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            The drop rate for stochastic depth.
+        out_features (`list[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`list[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+    ```python
+    >>> from transformers import ConvNextConfig, ConvNextModel
+
+    >>> # Initializing a ConvNext convnext-tiny-224 style configuration
+    >>> configuration = ConvNextConfig()
+
+    >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration
+    >>> model = ConvNextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "convnext"
+
+    def __init__(
+        self,
+        num_channels=3,
+        patch_size=4,
+        num_stages=4,
+        hidden_sizes=None,
+        depths=None,
+        hidden_act="gelu",
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        layer_scale_init_value=1e-6,
+        drop_path_rate=0.0,
+        image_size=224,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.num_stages = num_stages
+        self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
+        self.depths = [3, 3, 9, 3] if depths is None else depths
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.layer_scale_init_value = layer_scale_init_value
+        self.drop_path_rate = drop_path_rate
+        self.image_size = image_size
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+class ConvNextOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-5
+
+
+__all__ = ["ConvNextConfig", "ConvNextOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/feature_extraction_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/feature_extraction_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fbb5184cf37cb0aedcafeab3a6b363ab047d9a0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/feature_extraction_convnext.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ConvNeXT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_convnext import ConvNextImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class ConvNextFeatureExtractor(ConvNextImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+            " Please use ConvNextImageProcessor instead.",
+            FutureWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["ConvNextFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..af89274500ddb33160e92b5591bc9ac83c55f24c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext.py
@@ -0,0 +1,325 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ConvNeXT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+    center_crop,
+    get_resize_output_image_size,
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_flat_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class ConvNextImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a ConvNeXT image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
+            by `do_resize` in the `preprocess` method.
+        size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
+            Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+            resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+            be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
+            `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+            be overridden by `size` in the `preprocess` method.
+        crop_pct (`float`, *optional*, defaults to 224 / 256):
+            Percentage of the image to crop. Only has an effect if `do_resize` is `True` and `size["shortest_edge"]` < 384. Can be
+            overridden by `crop_pct` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[dict[str, int]] = None,
+        crop_pct: Optional[float] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 384}
+        size = get_size_dict(size, default_to_square=False)
+
+        self.do_resize = do_resize
+        self.size = size
+        # Default value set here for backwards compatibility where the value in config is None
+        self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: dict[str, int],
+        crop_pct: float,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+                `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+                Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+                after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+            crop_pct (`float`):
+                Percentage of the image to crop. Only has an effect if size < 384.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred from the input
+                image.
+        """
+        size = get_size_dict(size, default_to_square=False)
+        if "shortest_edge" not in size:
+            raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+        shortest_edge = size["shortest_edge"]
+
+        if shortest_edge < 384:
+            # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+            resize_shortest_edge = int(shortest_edge / crop_pct)
+            resize_size = get_resize_output_image_size(
+                image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
+            )
+            image = resize(
+                image=image,
+                size=resize_size,
+                resample=resample,
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+            # then crop to (shortest_edge, shortest_edge)
+            return center_crop(
+                image=image,
+                size=(shortest_edge, shortest_edge),
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+        else:
+            # warping (no cropping) when evaluated at 384 or larger
+            return resize(
+                image,
+                size=(shortest_edge, shortest_edge),
+                resample=resample,
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+
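+    # Worked example (illustrative note, not part of the original upstream file): with
+    # size={"shortest_edge": 224} and the default crop_pct=224/256, the shortest edge is first
+    # resized to int(224 / (224 / 256)) = 256 while preserving the aspect ratio, and the result
+    # is then center-cropped to (224, 224). With size={"shortest_edge": 384} or larger, the
+    # image is warped directly to (384, 384) and no cropping is applied.
+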
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        crop_pct: Optional[float] = None,
+        resample: Optional[PILImageResampling] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
+                is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
+                image will be matched to `int(size["shortest_edge"] / crop_pct)`, after which the image is cropped to
+                `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
+            crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+                Percentage of the image to crop if size < 384.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values to the range [0, 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+        resample = resample if resample is not None else self.resample
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        size = size if size is not None else self.size
+        size = get_size_dict(size, default_to_square=False)
+
+        images = make_flat_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        if do_resize:
+            images = [
+                self.resize(
+                    image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
+                )
+                for image in images
+            ]
+
+        if do_rescale:
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        if do_normalize:
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]
+
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ab00c0fd091369715e636bac663fbbbdc9239a0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/image_processing_convnext_fast.py
@@ -0,0 +1,180 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for ConvNeXT."""
+
+from typing import Optional, Union
+
+import torch
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_transforms import get_resize_output_image_size
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+    TensorType,
+    auto_docstring,
+)
+
+
+class ConvNextFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    """
+    crop_pct (`float`, *optional*):
+        Percentage of the image to crop. Only has an effect if `size["shortest_edge"]` < 384. Can be
+        overridden by `crop_pct` in the `preprocess` method.
+    """
+
+    crop_pct: Optional[float]
+
+
+@auto_docstring
+class ConvNextImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BILINEAR
+    image_mean = IMAGENET_STANDARD_MEAN
+    image_std = IMAGENET_STANDARD_STD
+    size = {"shortest_edge": 384}
+    default_to_square = False
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    crop_pct = 224 / 256
+    valid_kwargs = ConvNextFastImageProcessorKwargs
+
+    def __init__(self, **kwargs: Unpack[ConvNextFastImageProcessorKwargs]):
+        super().__init__(**kwargs)
+
+    @auto_docstring
+    def preprocess(self, images: ImageInput, **kwargs: Unpack[ConvNextFastImageProcessorKwargs]) -> BatchFeature:
+        return super().preprocess(images, **kwargs)
+
+    def resize(
+        self,
+        image: "torch.Tensor",
+        size: dict[str, int],
+        crop_pct: float,
+        interpolation: PILImageResampling = PILImageResampling.BICUBIC,
+        **kwargs,
+    ) -> "torch.Tensor":
+        """
+        Resize an image.
+
+        Args:
+            image (`torch.Tensor`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+                `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+                Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+                after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+            crop_pct (`float`):
+                Percentage of the image to crop. Only has an effect if size < 384.
+            interpolation (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+
+        Returns:
+            `torch.Tensor`: Resized image.
+        """
+        if not size.shortest_edge:
+            raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+        shortest_edge = size["shortest_edge"]
+
+        if shortest_edge < 384:
+            # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+            resize_shortest_edge = int(shortest_edge / crop_pct)
+            resize_size = get_resize_output_image_size(
+                image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST
+            )
+            image = F.resize(
+                image,
+                resize_size,
+                interpolation=interpolation,
+                **kwargs,
+            )
+            # then crop to (shortest_edge, shortest_edge)
+            return F.center_crop(
+                image,
+                (shortest_edge, shortest_edge),
+                **kwargs,
+            )
+        else:
+            # warping (no cropping) when evaluated at 384 or larger
+            return F.resize(
+                image,
+                (shortest_edge, shortest_edge),
+                interpolation=interpolation,
+                **kwargs,
+            )
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: dict[str, int],
+        crop_pct: float,
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: int,
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_center_crop:
+                stacked_images = self.center_crop(stacked_images, crop_size)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessorFast"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..3120c140d2ed41c16e659a780c7d4603dcfda7b7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_convnext.py
@@ -0,0 +1,424 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ConvNext model."""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+    BackboneOutput,
+    BaseModelOutputWithNoAttention,
+    BaseModelOutputWithPoolingAndNoAttention,
+    ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import can_return_tuple
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
+
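+# Illustrative note (not part of the upstream file): with drop_prob=0.1 and training=True, each
+# sample's residual branch is zeroed with probability 0.1 and scaled by 1 / 0.9 otherwise, so the
+# expected output matches the undropped input; with training=False the tensor passes through
+# unchanged.
+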
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
+class ConvNextDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return f"p={self.drop_prob}"
+
+
+class ConvNextLayerNorm(nn.LayerNorm):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+    """
+
+    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+        super().__init__(normalized_shape, eps=eps, **kwargs)
+        if data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError(f"Unsupported data format: {data_format}")
+        self.data_format = data_format
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+        """
+        if self.data_format == "channels_first":
+            features = features.permute(0, 2, 3, 1)
+            features = super().forward(features)
+            features = features.permute(0, 3, 1, 2)
+        else:
+            features = super().forward(features)
+        return features
+
+
+class ConvNextEmbeddings(nn.Module):
+    """This class is comparable to (and inspired by) the SwinEmbeddings class
+    found in src/transformers/models/swin/modeling_swin.py.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.patch_embeddings = nn.Conv2d(
+            config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
+        )
+        self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+        self.num_channels = config.num_channels
+
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        num_channels = pixel_values.shape[1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+        embeddings = self.patch_embeddings(pixel_values)
+        embeddings = self.layernorm(embeddings)
+        return embeddings
+
+
+class ConvNextLayer(nn.Module):
+    """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in
+    (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+    The authors used (2) as they find it slightly faster in PyTorch.
+
+    Args:
+        config ([`ConvNextConfig`]): Model configuration class.
+        dim (`int`): Number of input channels.
+        drop_path (`float`): Stochastic depth rate. Default: 0.0.
+    """
+
+    def __init__(self, config, dim, drop_path=0):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+        self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = ACT2FN[config.hidden_act]
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.layer_scale_parameter = (
+            nn.Parameter(config.layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if config.layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        residual = features
+        features = self.dwconv(features)
+        features = features.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        features = self.layernorm(features)
+        features = self.pwconv1(features)
+        features = self.act(features)
+        features = self.pwconv2(features)
+        if self.layer_scale_parameter is not None:
+            features = self.layer_scale_parameter * features
+        features = features.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+        features = residual + self.drop_path(features)
+        return features
+
+
+class ConvNextStage(nn.Module):
+    """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+    Args:
+        config ([`ConvNextConfig`]): Model configuration class.
+        in_channels (`int`): Number of input channels.
+        out_channels (`int`): Number of output channels.
+        depth (`int`): Number of residual blocks.
+        drop_path_rates(`list[float]`): Stochastic depth rates for each layer.
+    """
+
+    def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
+        super().__init__()
+
+        if in_channels != out_channels or stride > 1:
+            self.downsampling_layer = nn.ModuleList(
+                [
+                    ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
+                    nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
+                ]
+            )
+        else:
+            self.downsampling_layer = nn.ModuleList()
+        drop_path_rates = drop_path_rates or [0.0] * depth
+        self.layers = nn.ModuleList(
+            [ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
+        )
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        for layer in self.downsampling_layer:
+            features = layer(features)
+        for layer in self.layers:
+            features = layer(features)
+        return features
+
+
+class ConvNextEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.stages = nn.ModuleList()
+        drop_path_rates = [
+            x.tolist()
+            for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
+        ]
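+        # Illustrative note: for depths=[3, 3, 9, 3] and drop_path_rate=0.1, this produces 18
+        # linearly spaced rates from 0.0 to 0.1, split into per-stage lists of lengths 3, 3, 9
+        # and 3, so later blocks receive progressively higher stochastic-depth probabilities.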
+        prev_chs = config.hidden_sizes[0]
+        for i in range(config.num_stages):
+            out_chs = config.hidden_sizes[i]
+            stage = ConvNextStage(
+                config,
+                in_channels=prev_chs,
+                out_channels=out_chs,
+                stride=2 if i > 0 else 1,
+                depth=config.depths[i],
+                drop_path_rates=drop_path_rates[i],
+            )
+            self.stages.append(stage)
+            prev_chs = out_chs
+
+    def forward(
+        self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False
+    ) -> BaseModelOutputWithNoAttention:
+        all_hidden_states = [hidden_states] if output_hidden_states else None
+
+        for layer_module in self.stages:
+            hidden_states = layer_module(hidden_states)
+            if all_hidden_states is not None:
+                all_hidden_states.append(hidden_states)
+
+        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+
+@auto_docstring
+class ConvNextPreTrainedModel(PreTrainedModel):
+    config: ConvNextConfig
+    base_model_prefix = "convnext"
+    main_input_name = "pixel_values"
+    _no_split_modules = ["ConvNextLayer"]
+    _can_record_outputs = {}  # hidden states are collected explicitly
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, ConvNextLayer):
+            if module.layer_scale_parameter is not None:
+                module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value)
+
+
+@auto_docstring
+class ConvNextModel(ConvNextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ConvNextEmbeddings(config)
+        self.encoder = ConvNextEncoder(config)
+
+        # final layernorm layer
+        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self, pixel_values: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None
+    ) -> BaseModelOutputWithPoolingAndNoAttention:
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values)
+        encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(
+            embedding_output, output_hidden_states=output_hidden_states
+        )
+        last_hidden_state = encoder_outputs.last_hidden_state
+
+        # global average pooling, (N, C, H, W) -> (N, C)
+        pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
+
+        return BaseModelOutputWithPoolingAndNoAttention(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
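+# Minimal usage sketch (illustrative, assuming the public "facebook/convnext-tiny-224" checkpoint):
+#
+#     from transformers import AutoImageProcessor, ConvNextModel
+#     import torch
+#
+#     processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+#     model = ConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+#     inputs = processor(images=image, return_tensors="pt")  # `image`: a PIL.Image or np.ndarray
+#     with torch.no_grad():
+#         outputs = model(**inputs)
+#     pooled = outputs.pooler_output  # shape (batch_size, config.hidden_sizes[-1])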
+
+@auto_docstring(
+    custom_intro="""
+    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+    ImageNet.
+    """
+)
+class ConvNextForImageClassification(ConvNextPreTrainedModel):
+    accepts_loss_kwargs = False
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.convnext = ConvNextModel(config)
+
+        # Classifier head
+        if config.num_labels > 0:
+            self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
+        else:
+            self.classifier = nn.Identity()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self, pixel_values: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs
+    ) -> ImageClassifierOutputWithNoAttention:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnext(pixel_values, **kwargs)
+        pooled_output = outputs.pooler_output
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
+
+        return ImageClassifierOutputWithNoAttention(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
+    """
+)
+class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
+    has_attentions = False
+
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        self.embeddings = ConvNextEmbeddings(config)
+        self.encoder = ConvNextEncoder(config)
+        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+
+        # Add layer norms to hidden states of out_features
+        hidden_states_norms = {}
+        for stage, num_channels in zip(self._out_features, self.channels):
+            hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
+        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+        # initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+    ) -> BackboneOutput:
+        r"""
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
+
+        >>> inputs = processor(image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        ```"""
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        embedding_output = self.embeddings(pixel_values)
+        outputs: BaseModelOutputWithNoAttention = self.encoder(embedding_output, output_hidden_states=True)
+        hidden_states = outputs.hidden_states
+
+        feature_maps = []
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                feature_maps.append(hidden_state)
+
+        return BackboneOutput(
+            feature_maps=tuple(feature_maps),
+            hidden_states=hidden_states if output_hidden_states else None,
+        )
+
+
+__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_tf_convnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_tf_convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..7306877466d9b793682d01d90645f08420930b59
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/convnext/modeling_tf_convnext.py
@@ -0,0 +1,667 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 ConvNext model."""
+
+from __future__ import annotations
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
+from ...modeling_tf_utils import (
+    TFModelInputType,
+    TFPreTrainedModel,
+    TFSequenceClassificationLoss,
+    get_initializer,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "ConvNextConfig"
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+
+
+class TFConvNextDropPath(keras.layers.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    References:
+        (1) https://github.com/rwightman/pytorch-image-models
+    """
+
+    def __init__(self, drop_path: float, **kwargs):
+        super().__init__(**kwargs)
+        self.drop_path = drop_path
+
+    def call(self, x: tf.Tensor, training=None):
+        if training:
+            keep_prob = 1 - self.drop_path
+            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+            random_tensor = tf.floor(random_tensor)
+            return (x / keep_prob) * random_tensor
+        return x
+
+
+class TFConvNextEmbeddings(keras.layers.Layer):
+    """This class is comparable to (and inspired by) the SwinEmbeddings class
+    found in src/transformers/models/swin/modeling_swin.py.
+    """
+
+    def __init__(self, config: ConvNextConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.patch_embeddings = keras.layers.Conv2D(
+            filters=config.hidden_sizes[0],
+            kernel_size=config.patch_size,
+            strides=config.patch_size,
+            name="patch_embeddings",
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer=keras.initializers.Zeros(),
+        )
+        self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+        self.num_channels = config.num_channels
+        self.config = config
+
+    def call(self, pixel_values):
+        if isinstance(pixel_values, dict):
+            pixel_values = pixel_values["pixel_values"]
+
+        tf.debugging.assert_equal(
+            shape_list(pixel_values)[1],
+            self.num_channels,
+            message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
+        )
+
+        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+        # So change the input format from `NCHW` to `NHWC`.
+        # shape = (batch_size, in_height, in_width, in_channels)
+        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+        embeddings = self.patch_embeddings(pixel_values)
+        embeddings = self.layernorm(embeddings)
+        return embeddings
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "patch_embeddings", None) is not None:
+            with tf.name_scope(self.patch_embeddings.name):
+                self.patch_embeddings.build([None, None, None, self.config.num_channels])
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
+
+
+class TFConvNextLayer(keras.layers.Layer):
+    """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in
+    (N, C, H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
+
+    The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
+    NHWC ordering, we can just apply the operations straight-away without the permutation.
+
+    Args:
+        config ([`ConvNextConfig`]): Model configuration class.
+        dim (`int`): Number of input channels.
+        drop_path (`float`): Stochastic depth rate. Default: 0.0.
+    """
+
+    def __init__(self, config, dim, drop_path=0.0, **kwargs):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.config = config
+        self.dwconv = keras.layers.Conv2D(
+            filters=dim,
+            kernel_size=7,
+            padding="same",
+            groups=dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="dwconv",
+        )  # depthwise conv
+        self.layernorm = keras.layers.LayerNormalization(
+            epsilon=1e-6,
+            name="layernorm",
+        )
+        self.pwconv1 = keras.layers.Dense(
+            units=4 * dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="pwconv1",
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = get_tf_activation(config.hidden_act)
+        self.pwconv2 = keras.layers.Dense(
+            units=dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="pwconv2",
+        )
+        # Using `layers.Activation` instead of `tf.identity` to better control `training`
+        # behaviour.
+        self.drop_path = (
+            TFConvNextDropPath(drop_path, name="drop_path")
+            if drop_path > 0.0
+            else keras.layers.Activation("linear", name="drop_path")
+        )
+
+    def build(self, input_shape: tf.TensorShape = None):
+        # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
+        self.layer_scale_parameter = (
+            self.add_weight(
+                shape=(self.dim,),
+                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+                trainable=True,
+                name="layer_scale_parameter",
+            )
+            if self.config.layer_scale_init_value > 0
+            else None
+        )
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dwconv", None) is not None:
+            with tf.name_scope(self.dwconv.name):
+                self.dwconv.build([None, None, None, self.dim])
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, None, None, self.dim])
+        if getattr(self, "pwconv1", None) is not None:
+            with tf.name_scope(self.pwconv1.name):
+                self.pwconv1.build([None, None, self.dim])
+        if getattr(self, "pwconv2", None) is not None:
+            with tf.name_scope(self.pwconv2.name):
+                self.pwconv2.build([None, None, 4 * self.dim])
+        if getattr(self, "drop_path", None) is not None:
+            with tf.name_scope(self.drop_path.name):
+                self.drop_path.build(None)
+
+    def call(self, hidden_states, training=False):
+        input = hidden_states
+        x = self.dwconv(hidden_states)
+        x = self.layernorm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        if self.layer_scale_parameter is not None:
+            x = self.layer_scale_parameter * x
+
+        x = input + self.drop_path(x, training=training)
+        return x
+
+
+class TFConvNextStage(keras.layers.Layer):
+    """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+    Args:
+        config (`ConvNextConfig`):
+            Model configuration class.
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`):
+            Number of output channels.
+        depth (`int`):
+            Number of residual blocks.
+        drop_path_rates(`list[float]`):
+            Stochastic depth rates for each layer.
+    """
+
+    def __init__(
+        self,
+        config: ConvNextConfig,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 2,
+        stride: int = 2,
+        depth: int = 2,
+        drop_path_rates: list[float] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if in_channels != out_channels or stride > 1:
+            self.downsampling_layer = [
+                keras.layers.LayerNormalization(
+                    epsilon=1e-6,
+                    name="downsampling_layer.0",
+                ),
+                # Inputs to this layer will follow NHWC format since we
+                # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
+                # layer. All the outputs throughout the model will be in NHWC
+                # from this point on until the output where we again change to
+                # NCHW.
+                keras.layers.Conv2D(
+                    filters=out_channels,
+                    kernel_size=kernel_size,
+                    strides=stride,
+                    kernel_initializer=get_initializer(config.initializer_range),
+                    bias_initializer=keras.initializers.Zeros(),
+                    name="downsampling_layer.1",
+                ),
+            ]
+        else:
+            self.downsampling_layer = [tf.identity]
+
+        drop_path_rates = drop_path_rates or [0.0] * depth
+        self.layers = [
+            TFConvNextLayer(
+                config,
+                dim=out_channels,
+                drop_path=drop_path_rates[j],
+                name=f"layers.{j}",
+            )
+            for j in range(depth)
+        ]
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.stride = stride
+
+    def call(self, hidden_states):
+        for layer in self.downsampling_layer:
+            hidden_states = layer(hidden_states)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "layers", None) is not None:
+            for layer in self.layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+        if self.in_channels != self.out_channels or self.stride > 1:
+            with tf.name_scope(self.downsampling_layer[0].name):
+                self.downsampling_layer[0].build([None, None, None, self.in_channels])
+            with tf.name_scope(self.downsampling_layer[1].name):
+                self.downsampling_layer[1].build([None, None, None, self.in_channels])
+
+
+class TFConvNextEncoder(keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+        self.stages = []
+        drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
+        drop_path_rates = tf.split(drop_path_rates, config.depths)
+        drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
+        prev_chs = config.hidden_sizes[0]
+        for i in range(config.num_stages):
+            out_chs = config.hidden_sizes[i]
+            stage = TFConvNextStage(
+                config,
+                in_channels=prev_chs,
+                out_channels=out_chs,
+                stride=2 if i > 0 else 1,
+                depth=config.depths[i],
+                drop_path_rates=drop_path_rates[i],
+                name=f"stages.{i}",
+            )
+            self.stages.append(stage)
+            prev_chs = out_chs
+
+    def call(self, hidden_states, output_hidden_states=False, return_dict=True):
+        all_hidden_states = () if output_hidden_states else None
+
+        for i, layer_module in enumerate(self.stages):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states = layer_module(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+        return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+    def build(self, input_shape=None):
+        for stage in self.stages:
+            with tf.name_scope(stage.name):
+                stage.build(None)
+
+
+@keras_serializable
+class TFConvNextMainLayer(keras.layers.Layer):
+    config_class = ConvNextConfig
+
+    def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
+        self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
+        self.encoder = TFConvNextEncoder(config, name="encoder")
+        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+        # We are setting the `data_format` like so because from here on we will revert to the
+        # NCHW output format
+        self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
+
+    @unpack_inputs
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values, training=training)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        # Change to NCHW output format to have uniformity in the modules
+        last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+        pooled_output = self.layernorm(self.pooler(last_hidden_state))
+
+        # Change the other hidden state outputs to NCHW as well
+        if output_hidden_states:
+            hidden_states = tuple(tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1])
+
+        if not return_dict:
+            hidden_states = hidden_states if output_hidden_states else ()
+            return (last_hidden_state, pooled_output) + hidden_states
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, self.config.hidden_sizes[-1]])
+
+
+class TFConvNextPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ConvNextConfig
+    base_model_prefix = "convnext"
+    main_input_name = "pixel_values"
+
+
+CONVNEXT_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+    behavior.
+
+    
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
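+    As a purely illustrative sketch (assuming `model` is an instantiated `TFConvNextModel` and
+    `pixel_values` is a batch of preprocessed images), the following calls are equivalent:
+
+    ```python
+    >>> outputs = model(pixel_values)  # a single positional tensor
+    >>> outputs = model({"pixel_values": pixel_values})  # a dict in the first positional argument
+    >>> outputs = model(pixel_values=pixel_values)  # keyword arguments, PyTorch-style
+    ```
+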
+    
+
+    Parameters:
+        config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, each example of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`ConvNextImageProcessor.__call__`] for details.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+"""
+
+
+@add_start_docstrings(
+    "The bare ConvNext model outputting raw features without any specific head on top.",
+    CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextModel(TFConvNextPreTrainedModel):
+    def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TFConvNextModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state
+        ```"""
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        outputs = self.convnext(
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        if not return_dict:
+            return (outputs[0],) + outputs[1:]
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=outputs.last_hidden_state,
+            pooler_output=outputs.pooler_output,
+            hidden_states=outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "convnext", None) is not None:
+            with tf.name_scope(self.convnext.name):
+                self.convnext.build(None)
+
+
+@add_start_docstrings(
+    """
+    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+    ImageNet.
+    """,
+    CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.convnext = TFConvNextMainLayer(config, name="convnext")
+
+        # Classifier head
+        self.classifier = keras.layers.Dense(
+            units=config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            bias_initializer="zeros",
+            name="classifier",
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification
+        >>> import tensorflow as tf
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+        ```"""
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        outputs = self.convnext(
+            pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+        logits = self.classifier(pooled_output)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "convnext", None) is not None:
+            with tf.name_scope(self.convnext.name):
+                self.convnext.build(None)
+        if getattr(self, "classifier", None) is not None:
+            if hasattr(self.classifier, "name"):
+                with tf.name_scope(self.classifier.name):
+                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7000ac3d353bf4eba157b07e350b9ac5f7552a98
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_data2vec_audio import *
+    from .configuration_data2vec_text import *
+    from .configuration_data2vec_vision import *
+    from .modeling_data2vec_audio import *
+    from .modeling_data2vec_text import *
+    from .modeling_data2vec_vision import *
+    from .modeling_tf_data2vec_vision import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d88a9de6543afaa60c6d5353f2c32d871d5ee21
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_audio.py
@@ -0,0 +1,288 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecText configuration"""
+
+import math
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecAudioConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Data2VecAudioModel`]. It is used to instantiate
+    a Data2VecAudio model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Data2VecAudio
+    [facebook/data2vec-audio-base-960h](https://huggingface.co/facebook/data2vec-audio-base-960h) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the Data2VecAudio model. Defines the number of different tokens that can be represented
+            by the `inputs_ids` passed when calling [`Data2VecAudioModel`] or [`TFData2VecAudioModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for activations inside the fully connected layer.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        final_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the final projection layer of [`Data2VecAudioForCTC`].
+        layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability. See the [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more
+            details.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for output of the feature encoder.
+        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+        conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+        conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 5):
+            Number of stacked 1D convolutional positional embeddings layers.
+        conv_pos_kernel_size (`int`, *optional*, defaults to 19):
+            Kernel size of each 1D convolutional positional embeddings layer.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of each 1D convolutional positional embeddings layer.
+        mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+            mask_time_min_masks`.
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over
+            the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
+            span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
+            may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
+            True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespective of `mask_feature_prob`. Only relevant if
+            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`Data2VecAudioForCTC`].
+        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+            of [`Data2VecAudioForCTC`].
+        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+            instance of [`Data2VecAudioForSequenceClassification`].
+        classifier_proj_size (`int`, *optional*, defaults to 256):
+            Dimensionality of the projection before token mean-pooling for classification.
+        tdnn_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+            A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+            module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+        tdnn_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+            *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+        tdnn_dilation (`tuple[int]` or `list[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+            A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
+            *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+        xvector_output_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of the *XVector* embedding vectors.
+        add_adapter (`bool`, *optional*, defaults to `False`):
+            Whether a convolutional network should be stacked on top of the Data2VecAudio Encoder. Can be very useful
+            for warm-starting Data2VecAudio for SpeechEncoderDecoder models.
+        adapter_kernel_size (`int`, *optional*, defaults to 3):
+            Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+        adapter_stride (`int`, *optional*, defaults to 2):
+            Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+        num_adapter_layers (`int`, *optional*, defaults to 3):
+            Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+            True`.
+        output_hidden_size (`int`, *optional*):
+            Dimensionality of the encoder output layer. If not defined, this defaults to *hidden_size*. Only relevant
+            if `add_adapter is True`.
+
+    Example:
+
+    ```python
+    >>> from transformers import Data2VecAudioConfig, Data2VecAudioModel
+
+    >>> # Initializing a Data2VecAudio facebook/data2vec-audio-base-960h style configuration
+    >>> configuration = Data2VecAudioConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/data2vec-audio-base-960h style configuration
+    >>> model = Data2VecAudioModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "data2vec-audio"
+
+    def __init__(
+        self,
+        vocab_size=32,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout=0.1,
+        activation_dropout=0.1,
+        attention_dropout=0.1,
+        feat_proj_dropout=0.0,
+        final_dropout=0.1,
+        layerdrop=0.1,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        feat_extract_activation="gelu",
+        conv_dim=(512, 512, 512, 512, 512, 512, 512),
+        conv_stride=(5, 2, 2, 2, 2, 2, 2),
+        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+        conv_bias=False,
+        num_conv_pos_embedding_groups=16,
+        conv_pos_kernel_size=19,
+        num_conv_pos_embeddings=5,
+        mask_time_prob=0.05,
+        mask_time_length=10,
+        mask_time_min_masks=2,
+        mask_feature_prob=0.0,
+        mask_feature_length=10,
+        mask_feature_min_masks=0,
+        ctc_loss_reduction="sum",
+        ctc_zero_infinity=False,
+        use_weighted_layer_sum=False,
+        classifier_proj_size=256,
+        tdnn_dim=(512, 512, 512, 512, 1500),
+        tdnn_kernel=(5, 3, 3, 1, 1),
+        tdnn_dilation=(1, 2, 3, 1, 1),
+        xvector_output_dim=512,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        add_adapter=False,
+        adapter_kernel_size=3,
+        adapter_stride=2,
+        num_adapter_layers=3,
+        output_hidden_size=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+        self.hidden_size = hidden_size
+        self.feat_extract_activation = feat_extract_activation
+        self.conv_dim = list(conv_dim)
+        self.conv_stride = list(conv_stride)
+        self.conv_kernel = list(conv_kernel)
+        self.conv_bias = conv_bias
+        self.num_conv_pos_embeddings = num_conv_pos_embeddings
+        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+        self.conv_pos_kernel_size = conv_pos_kernel_size
+        self.num_feat_extract_layers = len(self.conv_dim)
+        self.num_hidden_layers = num_hidden_layers
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.num_attention_heads = num_attention_heads
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.feat_proj_dropout = feat_proj_dropout
+        self.final_dropout = final_dropout
+        self.layerdrop = layerdrop
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        self.vocab_size = vocab_size
+        self.use_weighted_layer_sum = use_weighted_layer_sum
+
+        if (
+            (len(self.conv_stride) != self.num_feat_extract_layers)
+            or (len(self.conv_kernel) != self.num_feat_extract_layers)
+            or (len(self.conv_dim) != self.num_feat_extract_layers)
+        ):
+            raise ValueError(
+                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+            )
+
+        # fine-tuning config parameters for SpecAugment: https://huggingface.co/papers/1904.08779
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
+        self.mask_time_min_masks = mask_time_min_masks
+        self.mask_feature_prob = mask_feature_prob
+        self.mask_feature_length = mask_feature_length
+        self.mask_feature_min_masks = mask_feature_min_masks
+
+        # ctc loss
+        self.ctc_loss_reduction = ctc_loss_reduction
+        self.ctc_zero_infinity = ctc_zero_infinity
+
+        # adapter
+        self.add_adapter = add_adapter
+        self.adapter_kernel_size = adapter_kernel_size
+        self.adapter_stride = adapter_stride
+        self.num_adapter_layers = num_adapter_layers
+        self.output_hidden_size = output_hidden_size or hidden_size
+
+        # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+        self.classifier_proj_size = classifier_proj_size
+
+        # XVector-specific parameters. Feel free to ignore for other classes.
+        self.tdnn_dim = list(tdnn_dim)
+        self.tdnn_kernel = list(tdnn_kernel)
+        self.tdnn_dilation = list(tdnn_dilation)
+        self.xvector_output_dim = xvector_output_dim
+
+    @property
+    def inputs_to_logits_ratio(self):
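+        # Overall downsampling factor of the feature encoder (product of all conv strides).
+        # With the default conv_stride of (5, 2, 2, 2, 2, 2, 2) this is 320, i.e. one output
+        # frame per 320 input samples (20 ms of audio at a 16 kHz sampling rate).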
+        return math.prod(self.conv_stride)
+
+
+__all__ = ["Data2VecAudioConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_text.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9518d67bf665f01a2cfb46cfdb6f529f5f22bce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_text.py
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecText configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecTextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Data2VecTextModel`]. It
+    is used to instantiate a Data2VecText model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Data2VecText
+    [facebook/data2vec-text-base](https://huggingface.co/facebook/data2vec-text-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the Data2VecText model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`Data2VecTextModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`Data2VecTextModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
+
+    Examples:
+
+    ```python
+    >>> from transformers import Data2VecTextConfig, Data2VecTextModel
+
+    >>> # Initializing a Data2VecText facebook/data2vec-text-base style configuration
+    >>> configuration = Data2VecTextConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/data2vec-text-base style configuration
+    >>> model = Data2VecTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "data2vec-text"
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        position_embedding_type="absolute",
+        use_cache=True,
+        classifier_dropout=None,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
+
+
+class Data2VecTextOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+            ]
+        )
+
+
+__all__ = ["Data2VecTextConfig", "Data2VecTextOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de256f9d7d7abda3b065472f116a371c03ee4da
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/configuration_data2vec_vision.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data2VecVision model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Data2VecVisionConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Data2VecVisionModel`]. It is used to instantiate
+    a Data2VecVision model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Data2VecVision
+    [facebook/data2vec-vision-base](https://huggingface.co/facebook/data2vec-vision-base) architecture.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        use_mask_token (`bool`, *optional*, defaults to `False`):
+            Whether to use a mask token for masked image modeling.
+        use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to use BERT-style absolute position embeddings.
+        use_relative_position_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use T5-style relative position embeddings in the self-attention layers.
+        use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
+        layer_scale_init_value (`float`, *optional*, defaults to 0.1):
+            Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate per sample (when applied in the main path of residual layers).
+        use_mean_pooling (`bool`, *optional*, defaults to `True`):
+            Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
+            CLS token, before applying the classification head.
+        out_indices (`list[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
+            Indices of the feature maps to use for semantic segmentation.
+        pool_scales (`tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
+            Pooling scales used in Pooling Pyramid Module applied on the last feature map.
+        use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+            Whether to use an auxiliary head during training.
+        auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+            Weight of the cross-entropy loss of the auxiliary head.
+        auxiliary_channels (`int`, *optional*, defaults to 256):
+            Number of channels to use in the auxiliary head.
+        auxiliary_num_convs (`int`, *optional*, defaults to 1):
+            Number of convolutional layers to use in the auxiliary head.
+        auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
+            Whether to concatenate the output of the auxiliary head with the input before the classification layer.
+        semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+            The index that is ignored by the loss function of the semantic segmentation model.
+
+    Example:
+
+    ```python
+    >>> from transformers import Data2VecVisionConfig, Data2VecVisionModel
+
+    >>> # Initializing a Data2VecVision data2vec_vision-base-patch16-224-in22k style configuration
+    >>> configuration = Data2VecVisionConfig()
+
+    >>> # Initializing a model (with random weights) from the data2vec_vision-base-patch16-224-in22k style configuration
+    >>> model = Data2VecVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "data2vec-vision"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        image_size=224,
+        patch_size=16,
+        num_channels=3,
+        use_mask_token=False,
+        use_absolute_position_embeddings=False,
+        use_relative_position_bias=False,
+        use_shared_relative_position_bias=False,
+        layer_scale_init_value=0.1,
+        drop_path_rate=0.1,
+        use_mean_pooling=True,
+        out_indices=[3, 5, 7, 11],
+        pool_scales=[1, 2, 3, 6],
+        use_auxiliary_head=True,
+        auxiliary_loss_weight=0.4,
+        auxiliary_channels=256,
+        auxiliary_num_convs=1,
+        auxiliary_concat_input=False,
+        semantic_loss_ignore_index=255,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.use_mask_token = use_mask_token
+        self.use_absolute_position_embeddings = use_absolute_position_embeddings
+        self.use_relative_position_bias = use_relative_position_bias
+        self.use_shared_relative_position_bias = use_shared_relative_position_bias
+        self.layer_scale_init_value = layer_scale_init_value
+        self.drop_path_rate = drop_path_rate
+        self.use_mean_pooling = use_mean_pooling
+        # decode head attributes (semantic segmentation)
+        self.out_indices = out_indices
+        self.pool_scales = pool_scales
+        # auxiliary head attributes (semantic segmentation)
+        self.use_auxiliary_head = use_auxiliary_head
+        self.auxiliary_loss_weight = auxiliary_loss_weight
+        self.auxiliary_channels = auxiliary_channels
+        self.auxiliary_num_convs = auxiliary_num_convs
+        self.auxiliary_concat_input = auxiliary_concat_input
+        self.semantic_loss_ignore_index = semantic_loss_ignore_index
+
+
+# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
+class Data2VecVisionOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+
+__all__ = ["Data2VecVisionConfig", "Data2VecVisionOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b3f01f42d4ee528c11e12865d0e020ed5f0cf7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -0,0 +1,1397 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/data2vec/modular_data2vec_audio.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_data2vec_audio.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    CausalLMOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+    Wav2Vec2BaseModelOutput,
+    XVectorOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
+from .configuration_data2vec_audio import Data2VecAudioConfig
+
+
+if is_torch_flex_attn_available():
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+class Data2VecAudioConvLayer(GradientCheckpointingLayer):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+
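+        # LayerNorm normalizes over the channel dimension, so temporarily move channels last.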
+        hidden_states = hidden_states.transpose(-2, -1)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(-2, -1)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+class Data2VecAudioPadLayer(nn.Module):
+    def __init__(self, num_conv_pos_embeddings):
+        super().__init__()
+        self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+    def forward(self, hidden_states):
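+        # With `padding = kernel_size // 2`, an even kernel size yields one extra output frame;
+        # it is trimmed here so the positional embedding matches the input length (odd kernels need no trim).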
+        if self.num_pad_remove > 0:
+            hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+        return hidden_states
+
+
+class Data2VecAudioPositionalConvLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=config.conv_pos_kernel_size,
+            padding=config.conv_pos_kernel_size // 2,
+            groups=config.num_conv_pos_embedding_groups,
+        )
+
+        self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
+        self.activation = ACT2FN[config.feat_extract_activation]
+        # no learnable parameters
+        self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.padding(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+class Data2VecAudioPositionalConvEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
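+        # A stack of `num_conv_pos_embeddings` positional conv layers, each with kernel size
+        # `conv_pos_kernel_size`, rather than a single wide positional convolution.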
+        self.layers = nn.ModuleList(
+            [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.transpose(1, 2)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+class Data2VecAudioFeatureEncoder(nn.Module):
+    """Construct the features from raw audio waveform"""
+
+    def __init__(self, config):
+        super().__init__()
+        self.conv_layers = nn.ModuleList(
+            [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+        )
+        self.gradient_checkpointing = False
+        self._requires_grad = True
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def forward(self, input_values):
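+        # Add a channel dimension: (batch, time) -> (batch, 1, time) for the first 1D convolution.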
+        hidden_states = input_values[:, None]
+
+        # make sure hidden_states require grad for gradient_checkpointing
+        if self._requires_grad and self.training:
+            hidden_states.requires_grad = True
+
+        for conv_layer in self.conv_layers:
+            hidden_states = conv_layer(hidden_states)
+
+        return hidden_states
+
+
+class Data2VecAudioFeatureProjection(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+        self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+        self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+    def forward(self, hidden_states):
+        # non-projected hidden states are needed for quantization
+        norm_hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(norm_hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states, norm_hidden_states
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class Data2VecAudioAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[Data2VecAudioConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # At the moment, encoder, decoder, and encoder-decoder attention kwargs are mixed together
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        # determine input shapes
+        bsz, tgt_len = hidden_states.shape[:-1]
+        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+        kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+        # get query proj
+        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+        current_states = key_value_states if is_cross_attention else hidden_states
+        key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+        value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            output_attentions=output_attentions,
+            head_mask=layer_head_mask,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights, None
+
+
+class Data2VecAudioFeedForward(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+    def forward(self, hidden_states):
+        hidden_states = self.intermediate_dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.intermediate_dropout(hidden_states)
+
+        hidden_states = self.output_dense(hidden_states)
+        hidden_states = self.output_dropout(hidden_states)
+        return hidden_states
+
+
+class Data2VecAudioEncoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config):
+        super().__init__()
+        self.attention = Data2VecAudioAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=False,
+            config=config,
+        )
+
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.feed_forward = Data2VecAudioFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+        attn_residual = hidden_states
+        hidden_states, attn_weights, _ = self.attention(
+            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+        )
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = attn_residual + hidden_states
+
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states + self.feed_forward(hidden_states)
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class Data2VecAudioEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if attention_mask is not None:
+            # make sure padded tokens output 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
+
+        attention_mask = self._update_full_mask(
+            attention_mask,
+            hidden_states,
+        )
+
+        position_embeddings = self.pos_conv_embed(hidden_states)
+        hidden_states = hidden_states + position_embeddings
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = self.training and dropout_probability < self.config.layerdrop
+            if not skip_the_layer or synced_gpus:
+                # under fsdp or deepspeed zero3 all gpus must run in sync
+                layer_outputs = layer(
+                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+                )
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    def _update_full_mask(
+        self,
+        attention_mask: Union[torch.Tensor, None],
+        inputs_embeds: torch.Tensor,
+    ):
+        if attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(attention_mask, torch.Tensor):
+                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        return attention_mask
+
+
+class Data2VecAudioAdapterLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            config.output_hidden_size,
+            2 * config.output_hidden_size,
+            config.adapter_kernel_size,
+            stride=config.adapter_stride,
+            padding=1,
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = nn.functional.glu(hidden_states, dim=1)
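+        # GLU halves the channel dimension, gating the 2 * output_hidden_size conv output back down
+        # to output_hidden_size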
+
+        return hidden_states
+
+
+class Data2VecAudioAdapter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        # feature dim might need to be down-projected
+        if config.output_hidden_size != config.hidden_size:
+            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+        else:
+            self.proj = self.proj_layer_norm = None
+
+        self.layers = nn.ModuleList(Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers))
+        self.layerdrop = config.layerdrop
+
+    def forward(self, hidden_states):
+        # down project hidden_states if necessary
+        if self.proj is not None and self.proj_layer_norm is not None:
+            hidden_states = self.proj(hidden_states)
+            hidden_states = self.proj_layer_norm(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+
+        for layer in self.layers:
+            layerdrop_prob = np.random.random()
+            if not self.training or (layerdrop_prob > self.layerdrop):
+                hidden_states = layer(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+@auto_docstring
+class Data2VecAudioPreTrainedModel(PreTrainedModel):
+    config: Data2VecAudioConfig
+    base_model_prefix = "data2vec_audio"
+    main_input_name = "input_values"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, Data2VecAudioFeatureProjection):
+            k = math.sqrt(1 / module.projection.in_features)
+            nn.init.uniform_(module.projection.weight, a=-k, b=k)
+            nn.init.uniform_(module.projection.bias, a=-k, b=k)
+        elif isinstance(module, Data2VecAudioPositionalConvLayer):
+            nn.init.constant_(module.conv.bias, 0)
+        elif isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            if module.bias is not None:
+                module.bias.data.zero_()
+            if module.weight is not None:
+                module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Conv1d):
+            nn.init.kaiming_normal_(module.weight)
+
+            if module.bias is not None:
+                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                nn.init.uniform_(module.bias, a=-k, b=k)
+
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+    ):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
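+        # Illustrative note: with the default configuration (strides 5,2,2,2,2,2,2 and kernels
+        # 10,3,3,3,3,2,2) the waveform is downsampled by a factor of 320, i.e. roughly one feature
+        # vector per 20 ms of 16 kHz audio.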
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+    ):
+        # Effectively attention_mask.sum(-1), but not in-place so that it can also run
+        # in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+        output_lengths = output_lengths.to(torch.long)
+
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output length indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
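+        # the reversed cumulative sum below turns the single 1 at position (length - 1) into a prefix
+        # of 1s covering positions 0 .. length - 1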
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
+
+def _compute_mask_indices(
+    shape: tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    attention_mask: Optional[torch.LongTensor] = None,
+    min_masks: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+    CPU as part of the preprocessing during training.
+
+    Args:
+        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+               the first element is the batch size and the second element is the length of the axis to span.
+        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+                    independently generated mask spans of length `mask_length` is computed by
+                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+                    actual percentage will be smaller.
+        mask_length: size of the mask
+        min_masks: minimum number of masked spans
+        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+                        each batch dimension.
+    """
+    batch_size, sequence_length = shape
+
+    if mask_length < 1:
+        raise ValueError("`mask_length` has to be bigger than 0.")
+
+    if mask_length > sequence_length:
+        raise ValueError(
+            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}`"
+        )
+
+    # epsilon is used for probabilistic rounding
+    epsilon = np.random.rand(1).item()
+
+    def compute_num_masked_span(input_length):
+        """Given input length, compute how many spans should be masked"""
+        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+        num_masked_span = max(num_masked_span, min_masks)
+
+        # make sure num masked span <= sequence_length
+        if num_masked_span * mask_length > sequence_length:
+            num_masked_span = sequence_length // mask_length
+
+        # make sure num_masked span is also <= input_length - (mask_length - 1)
+        if input_length - (mask_length - 1) < num_masked_span:
+            num_masked_span = max(input_length - (mask_length - 1), 0)
+
+        return num_masked_span
+
+    # compute number of masked spans in batch
+    input_lengths = (
+        attention_mask.detach().sum(-1).tolist()
+        if attention_mask is not None
+        else [sequence_length for _ in range(batch_size)]
+    )
+
+    # SpecAugment mask to fill
+    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+    spec_aug_mask_idxs = []
+
+    max_num_masked_span = compute_num_masked_span(sequence_length)
+
+    if max_num_masked_span == 0:
+        return spec_aug_mask
+
+    for input_length in input_lengths:
+        # compute num of masked spans for this input
+        num_masked_span = compute_num_masked_span(input_length)
+
+        # get random indices to mask
+        spec_aug_mask_idx = np.random.choice(
+            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+        )
+
+        # pick the first sampled index to serve as a dummy index for padding the vector, so that all
+        # batches end up with the same dimension despite probabilistic rounding. Reusing the first
+        # sample simply pads those vectors twice.
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+@auto_docstring
+class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
+    def __init__(self, config: Data2VecAudioConfig):
+        super().__init__(config)
+        self.config = config
+        self.feature_extractor = Data2VecAudioFeatureEncoder(config)
+        self.feature_projection = Data2VecAudioFeatureProjection(config)
+
+        # model only needs masking vector if mask prob is > 0.0
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+        self.encoder = Data2VecAudioEncoder(config)
+
+        self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.feature_extractor._freeze_parameters()
+
+    def _mask_hidden_states(
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://huggingface.co/papers/1904.08779).
+        """
+
+        # `config.apply_spec_augment` can set masking to False
+        if not getattr(self.config, "apply_spec_augment", True):
+            return hidden_states
+
+        # generate indices & apply SpecAugment along time axis
+        batch_size, sequence_length, hidden_size = hidden_states.size()
+
+        if mask_time_indices is not None:
+            # apply SpecAugment along time axis with given mask_time_indices
+            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+        elif self.config.mask_time_prob > 0 and self.training:
+            mask_time_indices = _compute_mask_indices(
+                (batch_size, sequence_length),
+                mask_prob=self.config.mask_time_prob,
+                mask_length=self.config.mask_time_length,
+                attention_mask=attention_mask,
+                min_masks=self.config.mask_time_min_masks,
+            )
+            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+        if self.config.mask_feature_prob > 0 and self.training:
+            # generate indices & apply SpecAugment along feature axis
+            mask_feature_indices = _compute_mask_indices(
+                (batch_size, hidden_size),
+                mask_prob=self.config.mask_feature_prob,
+                mask_length=self.config.mask_feature_length,
+                min_masks=self.config.mask_feature_min_masks,
+            )
+            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+            hidden_states[mask_feature_indices] = 0
+
+        return hidden_states
+
+    @auto_docstring
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, Data2VecAudioBaseModelOutput]:
+        r"""
+        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+            masked extracted features in *config.proj_codevector_dim* space.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, extract_features) + encoder_outputs[1:]
+
+        return Data2VecAudioBaseModelOutput(
+            last_hidden_state=hidden_states,
+            extract_features=extract_features,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
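+# Usage sketch (illustrative; assumes the public "facebook/data2vec-audio-base-960h" checkpoint and a
+# 16 kHz waveform array `raw_waveform`):
+#
+#   from transformers import AutoProcessor, Data2VecAudioModel
+#
+#   processor = AutoProcessor.from_pretrained("facebook/data2vec-audio-base-960h")
+#   model = Data2VecAudioModel.from_pretrained("facebook/data2vec-audio-base-960h")
+#   inputs = processor(raw_waveform, sampling_rate=16000, return_tensors="pt")
+#   outputs = model(**inputs)  # outputs.last_hidden_state: (batch, frames, hidden_size)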
+
+_HIDDEN_STATES_START_POSITION = 2
+
+
+@auto_docstring(
+    custom_intro="""
+    Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
+    """
+)
+class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
+    def __init__(self, config):
+        r"""
+        target_lang (`str`, *optional*):
+            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+            adapter.<lang>.bin. Only relevant when using an instance of [`Data2VecAudioForCTC`] with adapters. Uses 'eng'
+            by default.
+        """
+        super().__init__(config)
+
+        self.data2vec_audio = Data2VecAudioModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+                "or define `vocab_size` of your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.data2vec_audio.feature_extractor._freeze_parameters()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[tuple, CausalLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+            config.vocab_size - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+        outputs = self.data2vec_audio(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # retrieve loss input_lengths from attention_mask
+            attention_mask = (
+                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+            )
+            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+            # assuming that padded tokens are filled with -100
+            # when not being attended to
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1)
+            flattened_targets = labels.masked_select(labels_mask)
+
+            # ctc_loss doesn't support fp16
+            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+            with torch.backends.cudnn.flags(enabled=False):
+                loss = nn.functional.ctc_loss(
+                    log_probs,
+                    flattened_targets,
+                    input_lengths,
+                    target_lengths,
+                    blank=self.config.pad_token_id,
+                    reduction=self.config.ctc_loss_reduction,
+                    zero_infinity=self.config.ctc_zero_infinity,
+                )
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
+
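+# Decoding sketch (illustrative; assumes a matching processor/tokenizer is available): the CTC logits
+# above can be greedily decoded with
+#
+#   predicted_ids = torch.argmax(logits, dim=-1)
+#   transcription = processor.batch_decode(predicted_ids)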
+
+@auto_docstring(
+    custom_intro="""
+    Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+    SUPERB Keyword Spotting.
+    """
+)
+class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
+            )
+        self.data2vec_audio = Data2VecAudioModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.data2vec_audio.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.data2vec_audio.parameters():
+            param.requires_grad = False
+
+    @auto_docstring
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[tuple, SequenceClassifierOutput]:
+        r"""
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+            into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.data2vec_audio(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        if attention_mask is None:
+            pooled_output = hidden_states.mean(dim=1)
+        else:
+            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
+            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring
+class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Audio frame classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
+            )
+        self.data2vec_audio = Data2VecAudioModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels
+
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.data2vec_audio.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.data2vec_audio.parameters():
+            param.requires_grad = False
+
+    @auto_docstring
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, TokenClassifierOutput]:
+        r"""
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+            into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.data2vec_audio(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), dim=1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class AMSoftmaxLoss(nn.Module):
+    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+        super().__init__()
+        self.scale = scale
+        self.margin = margin
+        self.num_labels = num_labels
+        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+        self.loss = nn.CrossEntropyLoss()
+
+    def forward(self, hidden_states, labels):
+        labels = labels.flatten()
+        weight = nn.functional.normalize(self.weight, dim=0)
+        hidden_states = nn.functional.normalize(hidden_states, dim=1)
+        cos_theta = torch.mm(hidden_states, weight)
+        psi = cos_theta - self.margin
+
+        onehot = nn.functional.one_hot(labels, self.num_labels)
+        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+        loss = self.loss(logits, labels)
+
+        return loss
+
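+# AM-Softmax in brief: with L2-normalized embeddings and class weights, the target-class logit
+# cos(theta_y) is replaced by cos(theta_y) - margin and all logits are scaled by `scale` before the
+# cross-entropy, enforcing an additive cosine margin between classes.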
+
+class TDNNLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+        self.out_conv_dim = config.tdnn_dim[layer_id]
+        self.kernel_size = config.tdnn_kernel[layer_id]
+        self.dilation = config.tdnn_dilation[layer_id]
+
+        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+        self.activation = nn.ReLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+        if is_peft_available():
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
+        hidden_states = hidden_states.transpose(1, 2)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+@auto_docstring(
+    custom_intro="""
+    Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+    """
+)
+class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.data2vec_audio = Data2VecAudioModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+        self.tdnn = nn.ModuleList(tdnn_layers)
+
+        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.data2vec_audio.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.data2vec_audio.parameters():
+            param.requires_grad = False
+
+    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the TDNN layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return (input_length - kernel_size) // stride + 1
+
+        for kernel_size in self.config.tdnn_kernel:
+            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+        return input_lengths
+
+    @auto_docstring
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[tuple, XVectorOutput]:
+        r"""
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
+            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
+            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
+            into a tensor of type `torch.FloatTensor`. See [`Data2VecAudioProcessor.__call__`] for details.
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.data2vec_audio(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+
+        for tdnn_layer in self.tdnn:
+            hidden_states = tdnn_layer(hidden_states)
+
+        # Statistic Pooling
+        if attention_mask is None:
+            mean_features = hidden_states.mean(dim=1)
+            std_features = hidden_states.std(dim=1)
+        else:
+            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+            mean_features = []
+            std_features = []
+            for i, length in enumerate(tdnn_output_lengths):
+                mean_features.append(hidden_states[i, :length].mean(dim=0))
+                std_features.append(hidden_states[i, :length].std(dim=0))
+            mean_features = torch.stack(mean_features)
+            std_features = torch.stack(std_features)
+        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
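+        # statistic pooling concatenates the per-utterance mean and std over time, yielding a fixed
+        # 2 * tdnn_dim[-1] dimensional vector regardless of the input length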
+
+        output_embeddings = self.feature_extractor(statistic_pooling)
+        logits = self.classifier(output_embeddings)
+
+        loss = None
+        if labels is not None:
+            loss = self.objective(logits, labels)
+
+        if not return_dict:
+            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return XVectorOutput(
+            loss=loss,
+            logits=logits,
+            embeddings=output_embeddings,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "Data2VecAudioForAudioFrameClassification",
+    "Data2VecAudioForCTC",
+    "Data2VecAudioForSequenceClassification",
+    "Data2VecAudioForXVector",
+    "Data2VecAudioModel",
+    "Data2VecAudioPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_text.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..f866dd9144a627b938de67ed3d3f79816e849722
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_text.py
@@ -0,0 +1,1378 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecText model."""
+
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, gelu
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_data2vec_text import Data2VecTextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Data2VecText
+class Data2VecTextForTextEmbeddings(nn.Module):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
+
+        # End copy
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+        )
+
+    def forward(
+        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        # If token_type_ids is not provided, fall back to the all-zeros buffer registered in the constructor.
+        # The registered buffer lets users trace the model without passing token_type_ids (solves issue #5664).
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Data2VecText
+class Data2VecTextSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+        self.is_decoder = config.is_decoder
+        self.layer_idx = layer_idx
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = self.query(hidden_states)
+        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+            1, 2
+        )
+
+        is_updated = False
+        is_cross_attention = encoder_hidden_states is not None
+        if past_key_values is not None:
+            if isinstance(past_key_values, EncoderDecoderCache):
+                is_updated = past_key_values.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_values.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_values.self_attention_cache
+            else:
+                curr_past_key_value = past_key_values
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_values is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_layer = curr_past_key_value.layers[self.layer_idx].keys
+            value_layer = curr_past_key_value.layers[self.layer_idx].values
+        else:
+            key_layer = self.key(current_states)
+            key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+                1, 2
+            )
+            value_layer = self.value(current_states)
+            value_layer = value_layer.view(
+                batch_size, -1, self.num_attention_heads, self.attention_head_size
+            ).transpose(1, 2)
+
+            if past_key_values is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+                    past_key_values.is_updated[self.layer_idx] = True
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if past_key_values is not None:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in Data2VecTextModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, attention_probs
+
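+# Illustrative note on relative position scoring: with position_embedding_type="relative_key",
+# the pairwise distances for a length-3 query/key pair are
+#   [[ 0, -1, -2],
+#    [ 1,  0, -1],
+#    [ 2,  1,  0]]
+# and are shifted by `max_position_embeddings - 1` before indexing `distance_embedding`, so every
+# entry falls inside the (2 * max_position_embeddings - 1)-row embedding table.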
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+class Data2VecTextSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
+    "eager": Data2VecTextSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
+class Data2VecTextAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+        super().__init__()
+        self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config,
+            position_embedding_type=position_embedding_type,
+            layer_idx=layer_idx,
+        )
+        self.output = Data2VecTextSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class Data2VecTextIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput
+class Data2VecTextOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText
+class Data2VecTextLayer(GradientCheckpointingLayer):
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = Data2VecTextAttention(config, layer_idx=layer_idx)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = Data2VecTextAttention(
+                config, position_embedding_type="absolute", layer_idx=layer_idx
+            )
+        self.intermediate = Data2VecTextIntermediate(config)
+        self.output = Data2VecTextOutput(config)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            past_key_values=past_key_values,
+            cache_position=cache_position,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+                    " by setting `config.add_cross_attention=True`"
+                )
+
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
+
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
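+# Illustrative note on chunking: `apply_chunking_to_forward` splits the sequence dimension
+# (`seq_len_dim = 1`) into chunks of `config.chunk_size_feed_forward` and runs `feed_forward_chunk`
+# on each chunk independently, trading extra calls for lower peak activation memory. With the
+# default `chunk_size_feed_forward == 0`, the whole sequence is processed in a single pass.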
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Data2VecText
+class Data2VecTextEncoder(nn.Module):
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([Data2VecTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and self.config.is_decoder and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+
+            hidden_states = layer_outputs[0]
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    past_key_values,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class Data2VecTextPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+@auto_docstring
+class Data2VecTextPreTrainedModel(PreTrainedModel):
+    config: Data2VecTextConfig
+    base_model_prefix = "data2vec_text"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Data2VecTextForTextEmbeddings", "Data2VecTextLayer"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            if hasattr(module, "bias") and module.bias is not None:
+                module.bias.data.zero_()
+            if hasattr(module, "weight") and module.weight is not None:
+                module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class Data2VecTextModel(Data2VecTextPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+    Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument
+    and `add_cross_attention` set to `True`; `encoder_hidden_states` is then expected as an input to the forward pass.
+
+    .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
+
+    """
+
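+    # Minimal usage sketch for decoder mode (assumes the public "facebook/data2vec-text-base" checkpoint):
+    #
+    #   from transformers import Data2VecTextConfig, Data2VecTextModel
+    #
+    #   config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base")
+    #   config.is_decoder = True
+    #   config.add_cross_attention = True  # needed to consume `encoder_hidden_states`
+    #   model = Data2VecTextModel.from_pretrained("facebook/data2vec-text-base", config=config)
+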
+    def __init__(self, config, add_pooling_layer=True):
+        r"""
+        add_pooling_layer (bool, *optional*, defaults to `True`):
+            Whether to add a pooling layer
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = Data2VecTextForTextEmbeddings(config)
+        self.encoder = Data2VecTextEncoder(config)
+
+        self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = (
+                past_key_values[0][0].shape[-2]
+                if not isinstance(past_key_values, Cache)
+                else past_key_values.get_seq_length()
+            )
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+
+        if token_type_ids is None:
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.
+    """
+)
+class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `Data2VecTextLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+        self.lm_head = Data2VecTextLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, Data2VecTextForCausalLM, Data2VecTextConfig
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
+        >>> config = Data2VecTextConfig.from_pretrained("facebook/data2vec-text-base")
+        >>> config.is_decoder = True
+        >>> model = Data2VecTextForCausalLM.from_pretrained("facebook/data2vec-text-base", config=config)
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.data2vec_text(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+@auto_docstring
+class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `Data2VecTextForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+            )
+
+        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+        self.lm_head = Data2VecTextLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.data2vec_text(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(prediction_scores.device)
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
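+# Minimal masked-LM usage sketch (assumes the public "facebook/data2vec-text-base" checkpoint and
+# its RoBERTa-style "<mask>" token):
+#
+#   from transformers import AutoTokenizer, Data2VecTextForMaskedLM
+#
+#   tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
+#   model = Data2VecTextForMaskedLM.from_pretrained("facebook/data2vec-text-base")
+#   inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
+#   logits = model(**inputs).logits
+#   mask_pos = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+#   print(tokenizer.decode(logits[0, mask_pos].argmax(-1)))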
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead with Roberta->Data2VecText
+class Data2VecTextLMHead(nn.Module):
+    """Data2VecText Head for masked language modeling."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+        self.decoder.bias = self.bias
+
+    def forward(self, features, **kwargs):
+        x = self.dense(features)
+        x = gelu(x)
+        x = self.layer_norm(x)
+
+        # project back to size of vocabulary with bias
+        x = self.decoder(x)
+
+        return x
+
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
+
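+# Note on weight tying: the LM models above list `lm_head.decoder.weight`/`.bias` in
+# `_tied_weights_keys` and expose the decoder via `get_output_embeddings`/`set_output_embeddings`,
+# so `post_init()` can tie the output projection to the input word embeddings whenever
+# `config.tie_word_embeddings` is True (the default).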
+
+@auto_docstring(
+    custom_intro="""
+    Data2VecText Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+    pooled output) e.g. for GLUE tasks.
+    """
+)
+class Data2VecTextForSequenceClassification(Data2VecTextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+        self.classifier = Data2VecTextClassificationHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.data2vec_text(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
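+# Illustrative note on the loss selection above: when `config.problem_type` is unset, it is inferred
+# from the labels, e.g.
+#   num_labels == 1                      -> "regression" (MSELoss)
+#   num_labels > 1, integer-typed labels -> "single_label_classification" (CrossEntropyLoss)
+#   num_labels > 1, float-typed labels   -> "multi_label_classification" (BCEWithLogitsLoss)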
+
+@auto_docstring
+class Data2VecTextForMultipleChoice(Data2VecTextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.data2vec_text = Data2VecTextModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, MultipleChoiceModelOutput]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+            `input_ids` above)
+        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        flat_inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.data2vec_text(
+            flat_input_ids,
+            position_ids=flat_position_ids,
+            token_type_ids=flat_token_type_ids,
+            attention_mask=flat_attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=flat_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(reshaped_logits.device)
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring
+class Data2VecTextForTokenClassification(Data2VecTextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.data2vec_text(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Data2VecText
+class Data2VecTextClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+@auto_docstring
+class Data2VecTextForQuestionAnswering(Data2VecTextPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.data2vec_text = Data2VecTextModel(config, add_pooling_layer=False)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, QuestionAnsweringModelOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.data2vec_text(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension; squeeze it
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids (`torch.Tensor`): Tensor of input token ids.
+        padding_idx (`int`): Id of the padding token.
+        past_key_values_length (`int`, *optional*, defaults to 0): Number of positions already cached.
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    return incremental_indices.long() + padding_idx
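+# Illustrative (assuming padding_idx=1 and no cached positions): input_ids [[5, 6, 1, 1]] gives
+# mask [[1, 1, 0, 0]], cumsum * mask [[1, 2, 0, 0]], and position ids [[2, 3, 1, 1]], so padding
+# positions stay at padding_idx while real tokens count up from padding_idx + 1.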
+
+
+__all__ = [
+    "Data2VecTextForCausalLM",
+    "Data2VecTextForMaskedLM",
+    "Data2VecTextForMultipleChoice",
+    "Data2VecTextForQuestionAnswering",
+    "Data2VecTextForSequenceClassification",
+    "Data2VecTextForTokenClassification",
+    "Data2VecTextModel",
+    "Data2VecTextPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..f214f8eb6a0bcefe5eb6a7072864dfbbe4b0b22a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -0,0 +1,1348 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecVision model."""
+
+import collections.abc
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+    SemanticSegmenterOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import auto_docstring, logging, torch_int
+from .configuration_data2vec_vision import Data2VecVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Class for outputs of [`Data2VecVisionModel`].
+    """
+)
+# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision
+class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
+    r"""
+    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
+        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
+        will be returned.
+    """
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
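+# Illustrative: with drop_prob=0.1 at training time, each sample's residual branch is zeroed with
+# probability 0.1 and otherwise rescaled by 1 / 0.9, so the expected output matches evaluation mode.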
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision
+class Data2VecVisionDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return f"p={self.drop_prob}"
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
+class Data2VecVisionEmbeddings(nn.Module):
+    """
+    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+
+    """
+
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        if config.use_mask_token:
+            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        else:
+            self.mask_token = None
+        self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
+        self.patch_size = config.patch_size
+        self.image_size = (
+            config.image_size
+            if isinstance(config.image_size, collections.abc.Iterable)
+            else (config.image_size, config.image_size)
+        )
+        num_patches = self.patch_embeddings.num_patches
+        if config.use_absolute_position_embeddings:
+            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+        else:
+            self.position_embeddings = None
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method interpolates the pre-trained position encodings so the model can be used on higher-resolution
+        images. It is also adapted to support torch.jit tracing.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
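+        # Illustrative (assuming 224x224 pre-training with patch_size=16, i.e. a 14x14 grid): a 256x256
+        # input yields a 16x16 patch grid, so the 14x14 positional grid is bicubically resized to 16x16
+        # before being re-concatenated with the class-token position embedding.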
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+    ) -> torch.Tensor:
+        if self.position_embeddings is not None and interpolate_pos_encoding is not None:
+            warnings.warn(
+                "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
+                "interpolated to the input image size. The argument will be removed in transformers v4.51.0."
+            )
+
+        _, _, height, width = pixel_values.shape
+        embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
+        batch_size, seq_len, _ = embeddings.size()
+
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # replace the masked visual tokens by mask_tokens
+            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1 - w) + mask_tokens * w
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        if self.position_embeddings is not None:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings, (patch_height, patch_width)
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
+class Data2VecVisionPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+        self.patch_shape = patch_shape
+
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+
+        embeddings = self.projection(pixel_values)
+        patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+
+        return embeddings, (patch_height, patch_width)
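+    # Illustrative (assuming 224x224 inputs, patch_size=16, hidden_size=768): pixel_values (B, 3, 224, 224)
+    # -> Conv2d projection (B, 768, 14, 14) -> flatten/transpose (B, 196, 768), with a 14x14 patch grid.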
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
+class Data2VecVisionSelfAttention(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
+        super().__init__()
+        self.config = config
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+        self.has_relative_position_bias = bool(window_size)
+        if self.has_relative_position_bias:
+            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        resolution: Optional[tuple[int]] = None,
+    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # Add relative position bias if present.
+        if self.has_relative_position_bias:
+            height, width = resolution
+            window_size = (height // self.config.patch_size, width // self.config.patch_size)
+            attention_scores = attention_scores + self.relative_position_bias(
+                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+            )
+
+        # Add shared relative position bias if provided.
+        if relative_position_bias is not None:
+            attention_scores = attention_scores + relative_position_bias
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision
+class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        resolution: Optional[tuple[int]] = None,
+    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
+                "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
+                "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                relative_position_bias=relative_position_bias,
+                interpolate_pos_encoding=interpolate_pos_encoding,
+                resolution=resolution,
+            )
+
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+
+        attn_bias = None
+        if self.has_relative_position_bias:
+            height, width = resolution
+            window_size = (height // self.config.patch_size, width // self.config.patch_size)
+            attn_bias = self.relative_position_bias(
+                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+            )
+
+        # Add shared relative position bias if provided.
+        if relative_position_bias is not None:
+            if attn_bias is None:
+                attn_bias = relative_position_bias
+            else:
+                attn_bias += relative_position_bias
+
+        scaling = 1 / math.sqrt(self.attention_head_size)
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attn_bias,
+            dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=scaling,
+        )
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        return context_layer, None
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
+class Data2VecVisionSelfOutput(nn.Module):
+    """
+    The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
+    "eager": Data2VecVisionSelfAttention,
+    "sdpa": Data2VecVisionSdpaSelfAttention,
+}
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION
+class Data2VecVisionAttention(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
+        super().__init__()
+        self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, window_size=window_size
+        )
+        self.output = Data2VecVisionSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
+        interpolate_pos_encoding: bool = False,
+        resolution: Optional[tuple[int]] = None,
+    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+        self_outputs = self.attention(
+            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
+        )
+
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision
+class Data2VecVisionIntermediate(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision
+class Data2VecVisionOutput(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
+class Data2VecVisionLayer(GradientCheckpointingLayer):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(
+        self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0
+    ) -> None:
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = Data2VecVisionAttention(config, window_size=window_size)
+        self.intermediate = Data2VecVisionIntermediate(config)
+        self.output = Data2VecVisionOutput(config)
+        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        init_values = config.layer_scale_init_value
+        if init_values > 0:
+            self.lambda_1 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
+            self.lambda_2 = nn.Parameter(init_values * torch.ones(config.hidden_size), requires_grad=True)
+        else:
+            self.lambda_1, self.lambda_2 = None, None
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        relative_position_bias: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        resolution: Optional[tuple[int, int]] = None,
+    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+        self_attention_outputs = self.attention(
+            self.layernorm_before(hidden_states),  # in Data2VecVision, layernorm is applied before self-attention
+            head_mask,
+            output_attentions=output_attentions,
+            relative_position_bias=relative_position_bias,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            resolution=resolution,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # apply lambda_1 if present
+        if self.lambda_1 is not None:
+            attention_output = self.lambda_1 * attention_output
+
+        # first residual connection
+        hidden_states = self.drop_path(attention_output) + hidden_states
+
+        # in Data2VecVision, layernorm is also applied after self-attention
+        layer_output = self.layernorm_after(hidden_states)
+
+        layer_output = self.intermediate(layer_output)
+        layer_output = self.output(layer_output)
+
+        if self.lambda_2 is not None:
+            layer_output = self.lambda_2 * layer_output
+
+        # second residual connection
+        layer_output = self.drop_path(layer_output) + hidden_states
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision
+class Data2VecVisionRelativePositionBias(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None:
+        super().__init__()
+        self.window_size = window_size
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, config.num_attention_heads)
+        )  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token 2 cls & cls to cls
+
+    @compile_compatible_method_lru_cache(maxsize=10)
+    def generate_relative_position_index(self, window_size: tuple[int, int]) -> torch.Tensor:
+        """
+        This method creates the relative position index, modified to support arbitrary window sizes,
+        as introduced in [MiDaS v3.1](https://huggingface.co/papers/2307.14460).
+        """
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        # cls to token & token 2 cls & cls to cls
+        # get pair-wise relative position index for each token inside the window
+        window_area = window_size[0] * window_size[1]
+        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
+        coords = torch.stack(grid)  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = num_relative_distance - 3
+        relative_position_index[0:, 0] = num_relative_distance - 2
+        relative_position_index[0, 0] = num_relative_distance - 1
+        return relative_position_index
+
+    def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
+        """
+        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
+        """
+        old_height = 2 * self.window_size[0] - 1
+        old_width = 2 * self.window_size[1] - 1
+
+        new_height = 2 * window_size[0] - 1
+        new_width = 2 * window_size[1] - 1
+
+        old_relative_position_bias_table = self.relative_position_bias_table
+
+        old_num_relative_distance = self.num_relative_distance
+        new_num_relative_distance = new_height * new_width + 3
+
+        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
+
+        old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
+        new_sub_table = nn.functional.interpolate(
+            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
+        )
+        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
+
+        new_relative_position_bias_table = torch.cat(
+            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
+        )
+
+        relative_position_index = self.generate_relative_position_index(window_size)
+        relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
+
+        # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
+        relative_position_bias = relative_position_bias.view(
+            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
+        )
+        # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+
+        if interpolate_pos_encoding:
+            relative_position_bias = nn.functional.interpolate(
+                relative_position_bias.unsqueeze(1),
+                size=(dim_size, dim_size),
+                mode="bilinear",
+                align_corners=False,
+            ).squeeze(1)
+
+        return relative_position_bias.unsqueeze(0)
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
+class Data2VecVisionEncoder(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.has_relative_position_bias = config.use_shared_relative_position_bias
+        if self.has_relative_position_bias:
+            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
+
+        # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
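+        # Illustrative: drop_path_rate=0.1 over 12 layers yields rates evenly spaced from 0.0 to 0.1,
+        # so deeper layers are dropped more aggressively than earlier ones.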
+        self.layer = nn.ModuleList(
+            [
+                Data2VecVisionLayer(
+                    config,
+                    window_size=window_size if config.use_relative_position_bias else None,
+                    drop_path_rate=dpr[i],
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        interpolate_pos_encoding: bool = False,
+        resolution: Optional[tuple[int, int]] = None,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.has_relative_position_bias:
+                height, width = resolution
+                window_size = (height // self.config.patch_size, width // self.config.patch_size)
+                relative_position_bias = self.relative_position_bias(
+                    window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+                )
+            else:
+                relative_position_bias = None
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            layer_outputs = layer_module(
+                hidden_states,
+                head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+                relative_position_bias=relative_position_bias,
+                interpolate_pos_encoding=interpolate_pos_encoding,
+                resolution=resolution,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+@auto_docstring
+# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision
+class Data2VecVisionPreTrainedModel(PreTrainedModel):
+    config: Data2VecVisionConfig
+    base_model_prefix = "data2vec_vision"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Data2VecVisionLayer"]
+    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, Data2VecVisionEmbeddings):
+            module.cls_token.data.zero_()
+            if module.mask_token is not None:
+                module.mask_token.data.zero_()
+            if module.position_embeddings is not None:
+                module.position_embeddings.data.zero_()
+        elif isinstance(module, Data2VecVisionRelativePositionBias):
+            module.relative_position_bias_table.data.zero_()
+        elif isinstance(module, Data2VecVisionLayer):
+            if module.lambda_1 is not None:
+                module.lambda_1.data.fill_(self.config.layer_scale_init_value)
+                module.lambda_2.data.fill_(self.config.layer_scale_init_value)
+
+
+@auto_docstring
+# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False
+class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
+    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None:
+        r"""
+        add_pooling_layer (bool, *optional*, defaults to `False`):
+            Whether to add a pooling layer
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = Data2VecVisionEmbeddings(config)
+        self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
+
+        self.layernorm = (
+            nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        )
+        self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        resolution = pixel_values.shape[2:]
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            resolution=resolution,
+            return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return Data2VecVisionModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision
+class Data2VecVisionPooler(nn.Module):
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__()
+        self.layernorm = (
+            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.layernorm is not None:
+            # Mean pool the final hidden states of the patch tokens
+            patch_tokens = hidden_states[:, 1:, :]
+            pooled_output = self.layernorm(patch_tokens.mean(1))
+        else:
+            # Pool by simply taking the final hidden state of the [CLS] token
+            pooled_output = hidden_states[:, 0]
+
+        return pooled_output
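+    # Illustrative: hidden_states (B, 1 + num_patches, hidden_size) -> pooled_output (B, hidden_size),
+    # either as a layernormed mean over the patch tokens or as the final [CLS] hidden state.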
+
+
+@auto_docstring(
+    custom_intro="""
+    Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
+    the final hidden states of the patch tokens) e.g. for ImageNet.
+    """
+)
+# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision
+class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True)
+
+        # Classifier head
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, ImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        outputs = self.data2vec_vision(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision
+class Data2VecVisionConvModule(nn.Module):
+    """
+    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
+    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, tuple[int, int]],
+        padding: Union[int, tuple[int, int], str] = 0,
+        bias: bool = False,
+        dilation: Union[int, tuple[int, int]] = 1,
+    ) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=bias,
+            dilation=dilation,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.activation = nn.ReLU()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = self.conv(input)
+        output = self.bn(output)
+        output = self.activation(output)
+
+        return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
+class Data2VecVisionPyramidPoolingBlock(nn.Module):
+    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
+        super().__init__()
+        self.layers = [
+            nn.AdaptiveAvgPool2d(pool_scale),
+            Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
+        ]
+        for i, layer in enumerate(self.layers):
+            self.add_module(str(i), layer)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
+class Data2VecVisionPyramidPoolingModule(nn.Module):
+    """
+    Pyramid Pooling Module (PPM) used in PSPNet.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        align_corners (bool): align_corners argument of F.interpolate.
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(self, pool_scales: tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
+        super().__init__()
+        self.pool_scales = pool_scales
+        self.align_corners = align_corners
+        self.in_channels = in_channels
+        self.channels = channels
+        self.blocks = []
+        for i, pool_scale in enumerate(pool_scales):
+            block = Data2VecVisionPyramidPoolingBlock(
+                pool_scale=pool_scale, in_channels=in_channels, channels=channels
+            )
+            self.blocks.append(block)
+            self.add_module(str(i), block)
+
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        ppm_outs = []
+        for ppm in self.blocks:
+            ppm_out = ppm(x)
+            upsampled_ppm_out = nn.functional.interpolate(
+                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
+            )
+            ppm_outs.append(upsampled_ppm_out)
+        return ppm_outs
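+    # Illustrative (assuming pool_scales=(1, 2, 3, 6) and a (B, C, 14, 14) input): each block adaptively
+    # average-pools to 1x1 / 2x2 / 3x3 / 6x6, projects to `channels`, and is bilinearly upsampled back to 14x14.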
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision
+class Data2VecVisionUperHead(nn.Module):
+    """
+    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
+    [UPerNet](https://huggingface.co/papers/1807.10221).
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__()
+
+        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
+        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
+        self.channels = config.hidden_size
+        self.align_corners = False
+        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+        # PSP Module
+        self.psp_modules = Data2VecVisionPyramidPoolingModule(
+            self.pool_scales,
+            self.in_channels[-1],
+            self.channels,
+            align_corners=self.align_corners,
+        )
+        self.bottleneck = Data2VecVisionConvModule(
+            self.in_channels[-1] + len(self.pool_scales) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+        # FPN Module
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        for in_channels in self.in_channels[:-1]:  # skip the top layer
+            l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1)
+            fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        self.fpn_bottleneck = Data2VecVisionConvModule(
+            len(self.in_channels) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+    def psp_forward(self, inputs):
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        # build laterals
+        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+        laterals.append(self.psp_forward(encoder_hidden_states))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = laterals[i - 1].shape[2:]
+            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
+                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
+            )
+
+        # build outputs
+        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = nn.functional.interpolate(
+                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+            )
+        fpn_outs = torch.cat(fpn_outs, dim=1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.classifier(output)
+
+        return output
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision
+class Data2VecVisionFCNHead(nn.Module):
+    """
+    Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of
+    [FCNNet](https://huggingface.co/papers/1411.4038).
+
+    Args:
+        config (Data2VecVisionConfig): Configuration.
+        in_index (int): Index of the encoder hidden state to use. Default: 2.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        dilation (int): The dilation rate for convs in the head. Default: 1.
+
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(
+        self,
+        config: Data2VecVisionConfig,
+        in_index: int = 2,
+        kernel_size: int = 3,
+        dilation: Union[int, tuple[int, int]] = 1,
+    ) -> None:
+        super().__init__()
+        self.in_channels = config.hidden_size
+        self.channels = config.auxiliary_channels
+        self.num_convs = config.auxiliary_num_convs
+        self.concat_input = config.auxiliary_concat_input
+        self.in_index = in_index
+
+        conv_padding = (kernel_size // 2) * dilation
+        convs = []
+        convs.append(
+            Data2VecVisionConvModule(
+                self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
+            )
+        )
+        for i in range(self.num_convs - 1):
+            convs.append(
+                Data2VecVisionConvModule(
+                    self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
+                )
+            )
+        if self.num_convs == 0:
+            self.convs = nn.Identity()
+        else:
+            self.convs = nn.Sequential(*convs)
+        if self.concat_input:
+            self.conv_cat = Data2VecVisionConvModule(
+                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
+            )
+
+        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
+
+    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        # just take the relevant feature maps
+        hidden_states = encoder_hidden_states[self.in_index]
+        output = self.convs(hidden_states)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
+        output = self.classifier(output)
+        return output
+
+
+@auto_docstring
+# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision
+class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
+    def __init__(self, config: Data2VecVisionConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)
+
+        # FPNs
+        if len(self.config.out_indices) != 4:
+            raise ValueError(
+                "Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
+                "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
+                "a base-sized architecture."
+            )
+        self.fpn1 = nn.Sequential(
+            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
+            nn.BatchNorm2d(config.hidden_size),
+            nn.GELU(),
+            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
+        )
+        self.fpn2 = nn.Sequential(
+            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
+        )
+        self.fpn3 = nn.Identity()
+        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
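+        # Illustrative: starting from the 1/16-resolution ViT feature map, fpn1 upsamples 4x, fpn2 upsamples 2x,
+        # fpn3 keeps the resolution, and fpn4 downsamples 2x, giving the 4-level pyramid expected by the UPerNet head.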
+
+        # Semantic segmentation head(s)
+        self.decode_head = Data2VecVisionUperHead(config)
+        self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def compute_loss(self, logits, auxiliary_logits, labels):
+        # upsample logits to the images' original size
+        upsampled_logits = nn.functional.interpolate(
+            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+        )
+        if auxiliary_logits is not None:
+            upsampled_auxiliary_logits = nn.functional.interpolate(
+                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+            )
+        # compute weighted loss
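+        # (when the auxiliary head is enabled, its loss is added weighted by
+        # `config.auxiliary_loss_weight`, 0.4 by default; illustrative note)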
+        loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
+        main_loss = loss_fct(upsampled_logits, labels)
+        loss = main_loss
+        if auxiliary_logits is not None:
+            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
+            loss += self.config.auxiliary_loss_weight * auxiliary_loss
+
+        return loss
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, SemanticSegmenterOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
+        >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> # logits are of shape (batch_size, num_labels, height, width)
+        >>> logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if labels is not None and self.config.num_labels == 1:
+            raise ValueError("The number of labels should be greater than one")
+
+        outputs = self.data2vec_vision(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # we need the intermediate hidden states
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        # only keep certain features, and reshape
+        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
+        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
+        batch_size = pixel_values.shape[0]
+        patch_resolution = self.config.image_size // self.config.patch_size
+        features = [
+            x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
+        ]
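+        # e.g. for a 224x224 input with the base config, each kept hidden state goes from
+        # (batch_size, 197, 768) to (batch_size, 768, 14, 14) here (illustrative note).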
+
+        # apply FPNs
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
+
+        logits = self.decode_head(features)
+
+        auxiliary_logits = None
+        if self.auxiliary_head is not None:
+            auxiliary_logits = self.auxiliary_head(features)
+
+        loss = None
+        if labels is not None:
+            loss = self.compute_loss(logits, auxiliary_logits, labels)
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (logits,) + outputs[1:]
+            else:
+                output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SemanticSegmenterOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "Data2VecVisionForImageClassification",
+    "Data2VecVisionForSemanticSegmentation",
+    "Data2VecVisionModel",
+    "Data2VecVisionPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fa0fe1f811ee1748956c2481d28b0ad1ef516e0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modeling_tf_data2vec_vision.py
@@ -0,0 +1,1723 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF 2.0 Data2Vec Vision model."""
+
+from __future__ import annotations
+
+import collections.abc
+import math
+from dataclasses import dataclass
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPooling,
+    TFSemanticSegmenterOutput,
+    TFSequenceClassifierOutput,
+)
+from ...modeling_tf_utils import (
+    TFModelInputType,
+    TFPreTrainedModel,
+    TFSequenceClassificationLoss,
+    get_initializer,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_data2vec_vision import Data2VecVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "Data2VecVisionConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
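+# (197 = (224 // 16) ** 2 = 196 patch tokens + 1 [CLS] token for the default 224x224 images and
+# 16x16 patches; illustrative note.)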
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"
+
+
+@dataclass
+class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling):
+    """
+    Class for outputs of [`TFData2VecVisionModel`].
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
+            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
+            will be returned.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    pooler_output: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+class TFData2VecVisionDropPath(keras.layers.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    References:
+        (1) github.com:rwightman/pytorch-image-models
+    """
+
+    def __init__(self, drop_path, **kwargs):
+        super().__init__(**kwargs)
+        self.drop_path = drop_path
+
+    def call(self, x, training=None):
+        if training:
+            keep_prob = 1 - self.drop_path
+            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+            random_tensor = tf.floor(random_tensor)
+            return (x / keep_prob) * random_tensor
+        return x
+
+
+class TFData2VecVisionEmbeddings(keras.layers.Layer):
+    """
+    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+
+    """
+
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+
+        self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings")
+        self.num_patches = self.patch_embeddings.num_patches
+        self.config = config
+
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+
+    def build(self, input_shape=None):
+        self.cls_token = self.add_weight(
+            shape=(1, 1, self.config.hidden_size),
+            initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+            trainable=True,
+            name="cls_token",
+        )
+        if self.config.use_mask_token:
+            self.mask_token = self.add_weight(
+                shape=(1, 1, self.config.hidden_size),
+                initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+                trainable=True,
+                name="mask_token",
+            )
+        else:
+            self.mask_token = None
+
+        if self.config.use_absolute_position_embeddings:
+            self.position_embeddings = self.add_weight(
+                shape=(1, self.num_patches + 1, self.config.hidden_size),
+                initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+                trainable=True,
+                name="position_embeddings",
+            )
+        else:
+            self.position_embeddings = None
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "patch_embeddings", None) is not None:
+            with tf.name_scope(self.patch_embeddings.name):
+                self.patch_embeddings.build(None)
+
+    def call(self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None) -> tf.Tensor:
+        embeddings = self.patch_embeddings(pixel_values)
+        batch_size, seq_len, projection_dim = shape_list(embeddings)
+
+        cls_tokens = tf.tile(self.cls_token, (batch_size, 1, 1))
+
+        if bool_masked_pos is not None:
+            mask_tokens = tf.broadcast_to(self.mask_token, (batch_size, seq_len, projection_dim))
+            # replace the masked visual tokens by mask_tokens
+            w = bool_masked_pos[..., None]
+            w = tf.cast(w, mask_tokens.dtype)
+            # since TF doesn't support eager tensor assignment
+            embeddings = embeddings * (1 - w) + mask_tokens * w
+
+        embeddings = tf.concat([cls_tokens, embeddings], axis=1)
+        if self.position_embeddings is not None:
+            embeddings = embeddings + self.position_embeddings
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class TFData2VecVisionPatchEmbeddings(keras.layers.Layer):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+        self.patch_shape = patch_shape
+        self.num_channels = num_channels
+
+        self.projection = keras.layers.Conv2D(
+            filters=hidden_size,
+            kernel_size=patch_size,
+            strides=patch_size,
+            padding="valid",
+            data_format="channels_last",
+            kernel_initializer="glorot_uniform",  # following torch.nn.Linear
+            bias_initializer="zeros",
+            name="projection",
+        )
+
+    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
+        batch_size, num_channels, height, width = shape_list(pixel_values)
+        if tf.executing_eagerly():
+            if num_channels != self.num_channels:
+                raise ValueError(
+                    "Make sure that the channel dimension of the pixel values match with the one set in the"
+                    " configuration."
+                )
+            if height != self.image_size[0] or width != self.image_size[1]:
+                raise ValueError(
+                    f"Input image size ({height}*{width}) doesn't match model"
+                    f" ({self.image_size[0]}*{self.image_size[1]})."
+                )
+
+        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+        # So change the input format from `NCHW` to `NHWC`.
+        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+        projection = self.projection(pixel_values)
+
+        # Change the 2D spatial dimensions to a single temporal dimension.
+        # shape = (batch_size, num_patches, out_channels=embed_dim)
+        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
+
+        return tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "projection", None) is not None:
+            with tf.name_scope(self.projection.name):
+                self.projection.build([None, None, None, self.num_channels])
+
+
+class TFData2VecVisionSelfAttention(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs):
+        super().__init__(**kwargs)
+
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+                f"of attention heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+        self.query = keras.layers.Dense(
+            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        )
+        self.key = keras.layers.Dense(
+            units=self.all_head_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="key",
+            use_bias=False,
+        )
+        self.value = keras.layers.Dense(
+            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        )
+        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+
+        if window_size:
+            self.relative_position_bias = TFData2VecVisionRelativePositionBias(
+                config, window_size=window_size, name="relative_position_bias"
+            )
+        else:
+            self.relative_position_bias = None
+        self.config = config
+
+    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+        return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        relative_position_bias: TFData2VecVisionRelativePositionBias | None = None,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        batch_size = shape_list(hidden_states)[0]
+        mixed_query_layer = self.query(inputs=hidden_states)
+        mixed_key_layer = self.key(inputs=hidden_states)
+        mixed_value_layer = self.value(inputs=hidden_states)
+        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        # (batch size, num_heads, seq_len_q, seq_len_k)
+        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+        attention_scores = attention_scores / self.sqrt_att_head_size
+
+        # Add relative position bias if present.
+        if self.relative_position_bias is not None:
+            # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
+            # might complain about `Layer.call()` not being invoked properly. In this case this input
+            # i.e., 0.0 is not going to be used in any calculations so we're safe.
+            attention_scores = attention_scores + self.relative_position_bias(0.0)[None, ...]
+
+        # Add shared relative position bias if provided.
+        if relative_position_bias is not None:
+            attention_scores = attention_scores + relative_position_bias
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = tf.multiply(attention_probs, head_mask)
+
+        attention_output = tf.matmul(attention_probs, value_layer)
+        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+        # (batch_size, seq_len_q, all_head_size)
+        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "query", None) is not None:
+            with tf.name_scope(self.query.name):
+                self.query.build([None, None, self.config.hidden_size])
+        if getattr(self, "key", None) is not None:
+            with tf.name_scope(self.key.name):
+                self.key.build([None, None, self.config.hidden_size])
+        if getattr(self, "value", None) is not None:
+            with tf.name_scope(self.value.name):
+                self.value.build([None, None, self.config.hidden_size])
+        if getattr(self, "relative_position_bias", None) is not None:
+            with tf.name_scope(self.relative_position_bias.name):
+                self.relative_position_bias.build(None)
+
+
+class TFData2VecVisionSelfOutput(keras.layers.Layer):
+    """
+    The residual connection is defined in TFData2VecVisionLayer instead of here (as is the case with other models), due
+    to the layernorm applied before each block.
+    """
+
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, gamma=None, training: bool = False) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFData2VecVisionAttention(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs):
+        super().__init__(**kwargs)
+
+        self.attention = TFData2VecVisionSelfAttention(config, window_size=window_size, name="attention")
+        self.dense_output = TFData2VecVisionSelfOutput(config, name="output")
+
+    def prune_heads(self, heads):
+        raise NotImplementedError
+
+    def call(
+        self,
+        input_tensor: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        relative_position_bias: TFData2VecVisionRelativePositionBias | None = None,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        self_outputs = self.attention(
+            hidden_states=input_tensor,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            relative_position_bias=relative_position_bias,
+            training=training,
+        )
+        attention_output = self.dense_output(
+            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+        )
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "attention", None) is not None:
+            with tf.name_scope(self.attention.name):
+                self.attention.build(None)
+        if getattr(self, "dense_output", None) is not None:
+            with tf.name_scope(self.dense_output.name):
+                self.dense_output.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->Data2VecVision
+class TFData2VecVisionIntermediate(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFData2VecVisionOutput(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFData2VecVisionLayer(keras.layers.Layer):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(
+        self, config: Data2VecVisionConfig, window_size: tuple | None = None, drop_path_rate: float = 0.0, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.config = config
+
+        self.attention = TFData2VecVisionAttention(config, window_size=window_size, name="attention")
+        self.intermediate = TFData2VecVisionIntermediate(config, name="intermediate")
+        self.data2vec_output = TFData2VecVisionOutput(config, name="output")
+
+        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
+        # Using `layers.Activation` instead of `tf.identity` to better control `training`
+        # behaviour.
+        self.drop_path = (
+            TFData2VecVisionDropPath(drop_path_rate, name="drop_path")
+            if drop_path_rate > 0.0
+            else keras.layers.Activation("linear", name="drop_path")
+        )
+        self.init_values = config.layer_scale_init_value
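+        # lambda_1 / lambda_2 (created in `build`) implement LayerScale: learnable per-channel scales
+        # on the two residual branches, initialised to `layer_scale_init_value`, so a small init value
+        # lets each block start close to an identity mapping (illustrative note).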
+
+    def build(self, input_shape: tf.TensorShape = None):
+        if self.init_values > 0:
+            self.lambda_1 = self.add_weight(
+                shape=(self.config.hidden_size),
+                initializer="ones",
+                trainable=True,
+                name="lambda_1",
+            )
+            self.lambda_2 = self.add_weight(
+                shape=(self.config.hidden_size),
+                initializer="ones",
+                trainable=True,
+                name="lambda_2",
+            )
+            self.lambda_1.assign(self.init_values * tf.ones(self.config.hidden_size))
+            self.lambda_2.assign(self.init_values * tf.ones(self.config.hidden_size))
+        else:
+            self.lambda_1, self.lambda_2 = None, None
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "attention", None) is not None:
+            with tf.name_scope(self.attention.name):
+                self.attention.build(None)
+        if getattr(self, "intermediate", None) is not None:
+            with tf.name_scope(self.intermediate.name):
+                self.intermediate.build(None)
+        if getattr(self, "data2vec_output", None) is not None:
+            with tf.name_scope(self.data2vec_output.name):
+                self.data2vec_output.build(None)
+        if getattr(self, "layernorm_before", None) is not None:
+            with tf.name_scope(self.layernorm_before.name):
+                self.layernorm_before.build([None, None, self.config.hidden_size])
+        if getattr(self, "layernorm_after", None) is not None:
+            with tf.name_scope(self.layernorm_after.name):
+                self.layernorm_after.build([None, None, self.config.hidden_size])
+        if getattr(self, "drop_path", None) is not None:
+            with tf.name_scope(self.drop_path.name):
+                self.drop_path.build(None)
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        relative_position_bias: TFData2VecVisionRelativePositionBias | None = None,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        self_attention_outputs = self.attention(
+            # in Data2VecVision, layernorm is applied before self-attention
+            input_tensor=self.layernorm_before(inputs=hidden_states),
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            relative_position_bias=relative_position_bias,
+            training=training,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # apply lambda_1 if present
+        if self.lambda_1 is not None:
+            attention_output = self.lambda_1 * attention_output
+
+        # first residual connection
+        hidden_states = self.drop_path(attention_output) + hidden_states
+
+        # in Data2VecVision, layernorm is also applied after self-attention
+        layer_output = self.layernorm_after(hidden_states)
+
+        layer_output = self.intermediate(layer_output)
+        layer_output = self.data2vec_output(layer_output)
+
+        if self.lambda_2 is not None:
+            layer_output = self.lambda_2 * layer_output
+
+        # second residual connection
+        layer_output = self.drop_path(layer_output) + hidden_states
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+# Taken and modified from here:
+# https://github.com/leondgarse/keras_cv_attention_models/blob/main/keras_cv_attention_models/beit/beit.py#L28
+class TFData2VecVisionRelativePositionBias(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, window_size: tuple, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.config = config
+
+        self.window_size = window_size
+        # +3 for cls_token_pos_len
+        # window_size can be something like (14, 14)
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
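+        # e.g. for a (14, 14) window: (2 * 14 - 1) ** 2 + 3 = 732 entries, the extra 3 covering the
+        # cls-to-token, token-to-cls and cls-to-cls cases (illustrative note).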
+
+        self.relative_position_index = self.get_position_index()
+
+    def build(self, input_shape):
+        self.relative_position_bias_table = self.add_weight(
+            shape=(self.num_relative_distance, self.config.num_attention_heads),
+            initializer="zeros",
+            trainable=True,
+            name="relative_position_bias_table",
+        )  # [2*Wh-1 * 2*Ww-1, nH]
+        # cls to token & token to cls & cls to cls
+
+        super().build(input_shape)
+
+    def get_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        xx, yy = tf.meshgrid(range(self.window_size[0]), range(self.window_size[1]))
+        coords = tf.stack([yy, xx], axis=0)  # [2, Wh, Ww]
+        coords_flatten = tf.reshape(coords, [2, -1])  # [2, Wh*Ww]
+
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Wh*Ww, Wh*Ww]
+        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])  # [Wh*Ww, Wh*Ww, 2]
+
+        xx = (relative_coords[:, :, 0] + self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
+        yy = relative_coords[:, :, 1] + self.window_size[1] - 1
+        relative_coords = tf.stack([xx, yy], axis=-1)
+
+        relative_position_index = tf.reduce_sum(relative_coords, axis=-1)  # [Wh*Ww, Wh*Ww]
+
+        top = tf.ones((1, relative_position_index.shape[1]), dtype=relative_position_index.dtype) * (
+            self.num_relative_distance - 3
+        )
+        left = tf.ones((relative_position_index.shape[0], 1), dtype=relative_position_index.dtype) * (
+            self.num_relative_distance - 2
+        )
+        corner = tf.ones((1, 1), dtype=relative_position_index.dtype) * (self.num_relative_distance - 1)
+
+        left_corner = tf.concat([corner, left], axis=0)
+        relative_position_index = tf.concat([top, relative_position_index], axis=0)
+        relative_position_index = tf.concat([left_corner, relative_position_index], axis=1)  # [Wh*Ww + 1, Wh*Ww + 1]
+        return relative_position_index
+
+    def call(self, inputs=None) -> tf.Tensor:
+        relative_position_bias = tf.gather(self.relative_position_bias_table, self.relative_position_index, axis=0)
+        return tf.transpose(relative_position_bias, [2, 0, 1])
+
+
+class TFData2VecVisionEncoder(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, window_size: tuple | None = None, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        if config.use_shared_relative_position_bias:
+            self.relative_position_bias = TFData2VecVisionRelativePositionBias(
+                config, window_size=window_size, name="relative_position_bias"
+            )
+        else:
+            self.relative_position_bias = None
+
+        # stochastic depth decay rule
+        dpr = list(tf.linspace(0.0, config.drop_path_rate, config.num_hidden_layers))
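+        # e.g. with drop_path_rate=0.1 and 12 layers, the per-layer rates rise linearly from 0.0 for
+        # the first layer to 0.1 for the last (illustrative note).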
+        self.layer = [
+            TFData2VecVisionLayer(
+                config,
+                window_size=window_size if config.use_relative_position_bias else None,
+                drop_path_rate=dpr[i],
+                name=f"layer_._{i}",
+            )
+            for i in range(config.num_hidden_layers)
+        ]
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> tuple | TFBaseModelOutput:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            # Passing `0.0` to the `relative_position_bias()` layer because otherwise Keras
+            # might complain about `Layer.call()` not being invoked properly. In this case this input
+            # i.e., 0.0 is not going to be used in any calculations so we're safe.
+            relative_position_bias = (
+                self.relative_position_bias(0.0) if self.relative_position_bias is not None else None
+            )
+            layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+        return TFBaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "relative_position_bias", None) is not None:
+            with tf.name_scope(self.relative_position_bias.name):
+                self.relative_position_bias.build(None)
+        if getattr(self, "layer", None) is not None:
+            for layer in self.layer:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+@keras_serializable
+class TFData2VecVisionMainLayer(keras.layers.Layer):
+    config_class = Data2VecVisionConfig
+
+    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = True, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
+        self.add_pooling_layer = add_pooling_layer
+
+        self.embeddings = TFData2VecVisionEmbeddings(config, name="embeddings")
+        self.encoder = TFData2VecVisionEncoder(
+            config, window_size=self.embeddings.patch_embeddings.patch_shape, name="encoder"
+        )
+        self.layernorm = (
+            tf.identity
+            if config.use_mean_pooling
+            else keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+        )
+
+        # We are setting the `data_format` like so because from here on we will revert to the
+        # NCHW output format
+        self.pooler = TFData2VecVisionPooler(config, name="pooler") if add_pooling_layer else None
+
+    def get_input_embeddings(self) -> keras.layers.Layer:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        raise NotImplementedError
+
+    @unpack_inputs
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        bool_masked_pos: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> tuple | TFData2VecVisionModelOutputWithPooling:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        embedding_output = self.embeddings(pixel_values, bool_masked_pos, training=training)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return TFData2VecVisionModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "layernorm", None) is not None:
+            if hasattr(self.layernorm, "name"):
+                with tf.name_scope(self.layernorm.name):
+                    self.layernorm.build((None, self.config.hidden_size))
+        if getattr(self, "pooler", None) is not None:
+            with tf.name_scope(self.pooler.name):
+                self.pooler.build(None)
+
+
+class TFData2VecVisionPooler(keras.layers.Layer):
+    def __init__(self, config: Data2VecVisionConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.layernorm = (
+            keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+            if config.use_mean_pooling
+            else None
+        )
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        if self.layernorm is not None:
+            # Mean pool the final hidden states of the patch tokens
+            patch_tokens = hidden_states[:, 1:, :]
+            pooled_output = self.layernorm(tf.reduce_mean(patch_tokens, axis=1))
+        else:
+            # Pool by simply taking the final hidden state of the [CLS] token
+            pooled_output = hidden_states[:, 0]
+
+        return pooled_output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "layernorm", None) is not None:
+            if hasattr(self.layernorm, "name"):
+                with tf.name_scope(self.layernorm.name):
+                    self.layernorm.build((None, self.config.hidden_size))
+
+
+class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = Data2VecVisionConfig
+    base_model_prefix = "data2vec_vision"
+    main_input_name = "pixel_values"
+    _keys_to_ignore_on_load_unexpected = [r"relative_position_index"]
+
+
+DATA2VEC_VISION_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.).
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+    behavior.
+
+    <Tip>
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    </Tip>
+
+    Args:
+        config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]` `dict[str, tf.Tensor]` or `dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`BeitImageProcessor.__call__`] for details.
+
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
+            in eager mode, in graph mode the value will always be set to True.
+
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+    "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
+    DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
+    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.config = config
+
+        self.data2vec_vision = TFData2VecVisionMainLayer(
+            config, add_pooling_layer=add_pooling_layer, name="data2vec_vision"
+        )
+
+    def get_input_embeddings(self):
+        return self.data2vec_vision.get_input_embeddings()
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFData2VecVisionModelOutputWithPooling,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        bool_masked_pos: tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> tuple | TFData2VecVisionModelOutputWithPooling:
+        r"""
+        bool_masked_pos (`tf.Tensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        outputs = self.data2vec_vision(
+            pixel_values=pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "data2vec_vision", None) is not None:
+            with tf.name_scope(self.data2vec_vision.name):
+                self.data2vec_vision.build(None)
+
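+# Minimal usage sketch for the bare TF model (illustrative only, mirroring the PyTorch docstring
+# example earlier in this diff; `image` is assumed to be a PIL image):
+#
+#     from transformers import AutoImageProcessor, TFData2VecVisionModel
+#
+#     image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
+#     model = TFData2VecVisionModel.from_pretrained("facebook/data2vec-vision-base")
+#     inputs = image_processor(images=image, return_tensors="tf")
+#     outputs = model(**inputs)  # outputs.last_hidden_state has shape (1, 197, 768)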
+
+@add_start_docstrings(
+    """
+    Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
+    the final hidden states of the patch tokens) e.g. for ImageNet.
+    """,
+    DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=True, name="data2vec_vision")
+
+        # Classifier head
+        self.classifier = keras.layers.Dense(
+            units=config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="classifier",
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: bool | None = False,
+    ) -> TFSequenceClassifierOutput | tuple:
+        r"""
+        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.data2vec_vision(
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+        logits = self.classifier(pooled_output)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "data2vec_vision", None) is not None:
+            with tf.name_scope(self.data2vec_vision.name):
+                self.data2vec_vision.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, self.config.hidden_size])
+
+
+class TFData2VecVisionConvModule(keras.layers.Layer):
+    """
+    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
+    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
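+# Illustrative sketch (not part of the module): with drop_path=0.1, each sample in the batch is
+# zeroed with probability 0.1 during training and survivors are rescaled by 1 / 0.9, so the
+# expected output matches the input:
+#
+#     drop = TFData2VecVisionDropPath(drop_path=0.1)
+#     out = drop(tf.ones((4, 197, 768)), training=True)  # each of the 4 samples is all ~1.11 or all 0.0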
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int | tuple[int, int],
+        padding: str = "valid",
+        bias: bool = False,
+        dilation: int | tuple[int, int] = 1,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.conv = keras.layers.Conv2D(
+            filters=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            use_bias=bias,
+            dilation_rate=dilation,
+            name="conv",
+        )
+        self.bn = keras.layers.BatchNormalization(name="bn", momentum=0.9, epsilon=1e-5)
+        self.activation = tf.nn.relu
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+    def call(self, input: tf.Tensor) -> tf.Tensor:
+        output = self.conv(input)
+        output = self.bn(output)
+        output = self.activation(output)
+        return output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "conv", None) is not None:
+            with tf.name_scope(self.conv.name):
+                self.conv.build([None, None, None, self.in_channels])
+        if getattr(self, "bn", None) is not None:
+            with tf.name_scope(self.bn.name):
+                self.bn.build((None, None, None, self.out_channels))
+
+
+class TFAdaptiveAvgPool2D(keras.layers.Layer):
+    def __init__(self, output_dims: tuple[int, int], input_ordering: str = "NHWC", **kwargs):
+        super().__init__(**kwargs)
+        self.output_dims = output_dims
+        self.input_ordering = input_ordering
+        if input_ordering not in ("NCHW", "NHWC"):
+            raise ValueError("Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!")
+        self.h_axis = input_ordering.index("H")
+        self.w_axis = input_ordering.index("W")
+
+    def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
+        # Figure out which axis we're pooling on
+        if h_pooling:
+            axis = self.h_axis
+            output_dim = self.output_dims[0]
+        else:
+            axis = self.w_axis
+            output_dim = self.output_dims[1]
+        input_dim = inputs.shape[axis]
+
+        # Figure out the potential pooling windows
+        # This is the key idea - the torch op always uses only two
+        # consecutive pooling window sizes, like 3 and 4. Therefore,
+        # if we pool with both possible sizes, we simply need to gather
+        # the 'correct' pool at each position to reimplement the torch op.
+        small_window = math.ceil(input_dim / output_dim)
+        big_window = small_window + 1
+        if h_pooling:
+            output_dim = self.output_dims[0]
+            small_window_shape = (small_window, 1)
+            big_window_shape = (big_window, 1)
+        else:
+            output_dim = self.output_dims[1]
+            small_window_shape = (1, small_window)
+            big_window_shape = (1, big_window)
+
+        # For resizes to 1, or integer resizes, we can take quick shortcuts
+        if output_dim == input_dim:
+            return inputs
+        elif output_dim == 1:
+            return tf.reduce_mean(inputs, axis=axis, keepdims=True)
+        elif input_dim % output_dim == 0:
+            return tf.nn.avg_pool2d(
+                inputs,
+                ksize=small_window_shape,
+                strides=small_window_shape,
+                padding="VALID",
+                data_format=self.input_ordering,
+            )
+        # When upscaling by an integer factor we can also take a quick shortcut
+        elif output_dim > input_dim and output_dim % input_dim == 0:
+            return tf.repeat(inputs, repeats=output_dim // input_dim, axis=axis)
+
+        # For non-integer resizes, we pool with both possible window sizes and concatenate them
+        if output_dim < input_dim:
+            small_pool = tf.nn.avg_pool2d(
+                inputs, ksize=small_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
+            )
+            big_pool = tf.nn.avg_pool2d(
+                inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
+            )
+            both_pool = tf.concat([small_pool, big_pool], axis=axis)
+        else:
+            # When we're actually upscaling instead, then we build the pools a bit differently
+            small_pool = inputs
+            big_pool = tf.nn.avg_pool2d(
+                inputs, ksize=big_window_shape, strides=1, padding="VALID", data_format=self.input_ordering
+            )
+            both_pool = tf.concat([small_pool, big_pool], axis=axis)
+
+        # We compute vectors of the start and end positions for each pooling window
+        # Each (start, end) pair here corresponds to a single output position
+        window_starts = tf.math.floor((tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim)
+        window_starts = tf.cast(window_starts, tf.int64)
+        window_ends = tf.math.ceil((tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim) / output_dim)
+        window_ends = tf.cast(window_ends, tf.int64)
+
+        # pool_selector is a boolean array of shape (output_dim,): True means that output position uses the
+        # big receptive field, False means it uses the small receptive field
+        pool_selector = tf.cast(window_ends - window_starts - small_window, tf.bool)
+
+        # Since we concatenated the small and big pools, we need to do a bit of
+        # pointer arithmetic to get the indices of the big pools
+        small_indices = window_starts
+        big_indices = window_starts + small_pool.shape[axis]
+
+        # Finally, we use the pool_selector to generate a list of indices, one per output position
+        gather_indices = tf.where(pool_selector, big_indices, small_indices)
+
+        # Gathering from those indices yields the final, correct pooling
+        return tf.gather(both_pool, gather_indices, axis=axis)
+
+    def call(self, inputs: tf.Tensor):
+        if self.input_ordering == "NHWC":
+            input_shape = inputs.shape[1:3]
+        else:
+            input_shape = inputs.shape[2:]
+
+        # We break the task down into each possible case
+        # Firstly, if we're resizing down to 1, it's just tf.reduce_mean
+        if self.output_dims[0] == self.output_dims[1] == 1:
+            if self.input_ordering == "NHWC":
+                reduce_dims = [1, 2]
+            else:
+                reduce_dims = [2, 3]
+            return tf.reduce_mean(inputs, axis=reduce_dims, keepdims=True)
+        # Secondly, if we're resizing by an integer factor on both dimensions, we can take a quick shortcut
+        elif input_shape[0] % self.output_dims[0] == 0 and input_shape[1] % self.output_dims[1] == 0:
+            h_resize = int(input_shape[0] // self.output_dims[0])
+            w_resize = int(input_shape[1] // self.output_dims[1])
+            return tf.nn.avg_pool2d(
+                inputs,
+                ksize=(h_resize, w_resize),
+                strides=(h_resize, w_resize),
+                padding="VALID",
+                data_format=self.input_ordering,
+            )
+        else:
+            # Finally, if we can't take the shortcut, we do a 1D pool on each axis. pseudo_1d_pool will take a shortcut
+            # for dimensions where an integer resize is possible. It can also handle upscaling.
+            h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
+            return self.pseudo_1d_pool(h_pooled, h_pooling=False)
+
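+# Editor's note: a minimal usage sketch (not part of the upstream module) showing how the
+# TFAdaptiveAvgPool2D layer above mimics torch-style adaptive average pooling on NHWC inputs.
+# The helper name `_demo_adaptive_avg_pool` and the example shapes are illustrative assumptions.
+def _demo_adaptive_avg_pool() -> tf.Tensor:
+    pool = TFAdaptiveAvgPool2D(output_dims=(3, 5))  # target spatial grid (H, W)
+    images = tf.random.normal((2, 7, 11, 4))  # NHWC input with non-integer resize ratios
+    pooled = pool(images)  # pseudo_1d_pool runs once per spatial axis
+    assert pooled.shape == (2, 3, 5, 4)
+    return pooled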
+
+class TFData2VecVisionPyramidPoolingModule(keras.layers.Layer):
+    """
+    Pyramid Pooling Module (PPM) used in PSPNet.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+        channels (int): Channels after modules, before conv_seg.
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(self, pool_scales: tuple[int, ...], in_channels: int, out_channels: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.pool_scales = pool_scales
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.layer_list = []
+        for idx, pool_scale in enumerate(pool_scales):
+            pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
+            self.layer_list.append(
+                [
+                    TFAdaptiveAvgPool2D(output_dims=pool_scale),
+                    TFData2VecVisionConvModule(
+                        in_channels=in_channels, out_channels=self.out_channels, kernel_size=1, name=f"{idx}.1"
+                    ),
+                ]
+            )
+
+    def call(self, x: tf.Tensor) -> list[tf.Tensor]:
+        ppm_outs = []
+        inputs = x
+
+        for ppm in self.layer_list:
+            for layer_module in ppm:
+                ppm_out = layer_module(x)
+                x = ppm_out
+
+            upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear")
+            ppm_outs.append(upsampled_ppm_out)
+        return ppm_outs
+
+    def build(self, input_shape=None):
+        for layer in self.layer_list:
+            for layer_module in layer:
+                with tf.name_scope(layer_module.name):
+                    layer_module.build(None)
+
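+# Editor's note: a minimal usage sketch (not part of the upstream module) for the pyramid pooling
+# module above: each pooling scale is projected to `out_channels` and resized back to the input
+# resolution, so the branch outputs can be concatenated along the channel axis. Names and sizes
+# here are illustrative assumptions.
+def _demo_pyramid_pooling() -> tf.Tensor:
+    ppm = TFData2VecVisionPyramidPoolingModule(pool_scales=(1, 2, 3, 6), in_channels=32, out_channels=16)
+    features = tf.random.normal((1, 14, 14, 32))
+    branch_outputs = ppm(features)  # one tensor per pooling scale
+    assert all(out.shape == (1, 14, 14, 16) for out in branch_outputs)
+    return tf.concat(branch_outputs, axis=-1)  # (1, 14, 14, 4 * 16)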
+
+class TFData2VecVisionUperHead(keras.layers.Layer):
+    """
+    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
+    [UPerNet](https://huggingface.co/papers/1807.10221).
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
+        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
+        self.channels = config.hidden_size
+        self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
+
+        # PSP Module
+        self.psp_modules = TFData2VecVisionPyramidPoolingModule(
+            self.pool_scales, self.in_channels[-1], self.channels, name="psp_modules"
+        )
+        self.bottleneck = TFData2VecVisionConvModule(
+            self.in_channels[-1] + len(self.pool_scales) * self.channels,
+            self.channels,
+            kernel_size=3,
+            padding="same",
+            name="bottleneck",
+        )
+        # FPN Module
+        self.lateral_convs = []
+        self.fpn_convs = []
+        for idx, in_channels in enumerate(self.in_channels[:-1]):  # skip the top layer
+            l_conv = TFData2VecVisionConvModule(
+                in_channels, out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}"
+            )
+            fpn_conv = TFData2VecVisionConvModule(
+                in_channels=self.channels,
+                out_channels=self.channels,
+                kernel_size=3,
+                padding="same",
+                name=f"fpn_convs.{idx}",
+            )
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        self.fpn_bottleneck = TFData2VecVisionConvModule(
+            in_channels=len(self.in_channels) * self.channels,
+            out_channels=self.channels,
+            kernel_size=3,
+            padding="same",
+            name="fpn_bottleneck",
+        )
+
+    def psp_forward(self, inputs):
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = tf.concat(psp_outs, axis=-1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
+        # build laterals
+        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
+
+        laterals.append(self.psp_forward(encoder_hidden_states))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = shape_list(laterals[i - 1])[1:-1]
+            laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear")
+
+        # build outputs
+        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear")
+        fpn_outs = tf.concat(fpn_outs, axis=-1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.classifier(output)
+
+        return output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, None, self.channels])
+        if getattr(self, "psp_modules", None) is not None:
+            with tf.name_scope(self.psp_modules.name):
+                self.psp_modules.build(None)
+        if getattr(self, "bottleneck", None) is not None:
+            with tf.name_scope(self.bottleneck.name):
+                self.bottleneck.build(None)
+        if getattr(self, "fpn_bottleneck", None) is not None:
+            with tf.name_scope(self.fpn_bottleneck.name):
+                self.fpn_bottleneck.build(None)
+        for layer in self.lateral_convs:
+            with tf.name_scope(layer.name):
+                layer.build(None)
+        for layer in self.fpn_convs:
+            with tf.name_scope(layer.name):
+                layer.build(None)
+
+
+class TFData2VecVisionFCNHead(keras.layers.Layer):
+    """
+    Fully Convolution Networks for Semantic Segmentation. This head is implemented from
+    [FCNNet](https://huggingface.co/papers/1411.4038).
+
+    Args:
+        config (Data2VecVisionConfig): Configuration.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        dilation (int): The dilation rate for convs in the head. Default: 1.
+
+
+    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
+    """
+
+    def __init__(
+        self,
+        config: Data2VecVisionConfig,
+        in_index: int = 2,
+        kernel_size: int = 3,
+        dilation: int | tuple[int, int] = 1,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.in_channels = config.hidden_size
+        self.channels = config.auxiliary_channels
+        self.num_convs = config.auxiliary_num_convs
+        self.concat_input = config.auxiliary_concat_input
+        self.in_index = in_index
+
+        convs = []
+        convs.append(
+            TFData2VecVisionConvModule(
+                in_channels=self.in_channels,
+                out_channels=self.channels,
+                kernel_size=kernel_size,
+                padding="same",
+                dilation=dilation,
+                name="convs.0",
+            )
+        )
+        for i in range(self.num_convs - 1):
+            convs.append(
+                TFData2VecVisionConvModule(
+                    in_channels=self.channels,
+                    out_channels=self.channels,
+                    kernel_size=kernel_size,
+                    padding="same",
+                    dilation=dilation,
+                    name=f"conv_module_{i + 2}",
+                )
+            )
+        if self.num_convs == 0:
+            self.convs = [tf.identity]
+        else:
+            self.convs = convs
+        if self.concat_input:
+            self.conv_cat = TFData2VecVisionConvModule(
+                self.in_channels + self.channels,
+                out_channels=self.channels,
+                kernel_size=kernel_size,
+                padding="same",
+                name="conv_cat",
+            )
+
+        self.classifier = keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
+
+    def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
+        # just take the relevant feature maps
+        hidden_states = encoder_hidden_states[self.in_index]
+        output = hidden_states
+        for layer_module in self.convs:
+            output = layer_module(output)
+        if self.concat_input:
+            output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
+        output = self.classifier(output)
+        return output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, None, self.channels])
+        if getattr(self, "conv_cat", None) is not None:
+            with tf.name_scope(self.conv_cat.name):
+                self.conv_cat.build(None)
+
+
+@add_start_docstrings(
+    """
+    Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
+    """,
+    DATA2VEC_VISION_START_DOCSTRING,
+)
+class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
+    def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+        self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")
+
+        # FPNs
+        self.fpn1 = [
+            keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
+            keras.layers.BatchNormalization(name="fpn1.1", momentum=0.9, epsilon=1e-5),
+            keras.layers.Activation("gelu"),
+            keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
+        ]
+        self.fpn2 = [keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]
+
+        self.fpn3 = tf.identity
+        self.fpn4 = keras.layers.MaxPool2D(pool_size=2, strides=2)
+
+        # Semantic segmentation head(s)
+        self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
+        self.auxiliary_head = (
+            TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
+        )
+
+    def compute_loss(self, logits, auxiliary_logits, labels):
+        # upsample logits to the images' original size
+        if len(shape_list(labels)) > 3:
+            label_interp_shape = shape_list(labels)[1:-1]
+        else:
+            label_interp_shape = shape_list(labels)[-2:]
+
+        upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+
+        # compute weighted loss
+        loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+
+        # Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
+        # Utility to mask the index to ignore during computing the loss.
+        def masked_loss(real, pred):
+            mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
+            loss_ = loss_fct(real, pred)
+            mask = tf.cast(mask, dtype=loss_.dtype)
+            loss_ *= mask
+            reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
+            return tf.reshape(reduced_masked_loss, (1,))
+
+        loss = masked_loss(labels, upsampled_logits)
+        # only add the auxiliary term when an auxiliary head is actually used (avoids a NameError otherwise)
+        if auxiliary_logits is not None:
+            upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
+            auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
+            loss = loss + self.config.auxiliary_loss_weight * auxiliary_loss
+
+        return loss
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        labels: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+    ) -> tuple | TFSemanticSegmenterOutput:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
+            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TFData2VecVisionForSemanticSegmentation
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
+        >>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> # logits are of shape (batch_size, num_labels, height, width)
+        >>> logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if labels is not None and self.config.num_labels == 1:
+            raise ValueError("The number of labels should be greater than one")
+
+        outputs = self.data2vec_vision(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # we need the intermediate hidden states
+            return_dict=return_dict,
+        )
+        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        # only keep certain features, and reshape
+        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
+        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
+        patch_resolution = self.config.image_size // self.config.patch_size
+
+        def reshape_features(x):
+            # We do it this way so TF can always infer the non-batch dims at compile time
+            x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size))
+            return x
+
+        features = [reshape_features(x[:, 1:, :]) for x in features]
+
+        # apply FPNs
+        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+        for module in ops[0]:
+            features[0] = module(features[0])
+        features[1] = ops[1][0](features[1])
+        for i in range(len(features[2:])):
+            features[i + 2] = ops[i + 2](features[i + 2])
+
+        logits = self.decode_head(features)
+        # Transpose the logits to maintain consistency in the output formats.
+        transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])
+
+        auxiliary_logits = None
+        if self.auxiliary_head is not None:
+            auxiliary_logits = self.auxiliary_head(features)
+
+        loss = None
+        if labels is not None:
+            loss = self.compute_loss(logits, auxiliary_logits, labels)
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (logits,) + outputs[1:]
+            else:
+                output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSemanticSegmenterOutput(
+            loss=loss,
+            logits=transposed_logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "data2vec_vision", None) is not None:
+            with tf.name_scope(self.data2vec_vision.name):
+                self.data2vec_vision.build(None)
+        if getattr(self, "decode_head", None) is not None:
+            with tf.name_scope(self.decode_head.name):
+                self.decode_head.build(None)
+        if getattr(self, "auxiliary_head", None) is not None:
+            with tf.name_scope(self.auxiliary_head.name):
+                self.auxiliary_head.build(None)
+        if getattr(self, "fpn1", None) is not None:
+            with tf.name_scope(self.fpn1[0].name):
+                self.fpn1[0].build([None, None, None, self.config.hidden_size])
+            with tf.name_scope(self.fpn1[1].name):
+                self.fpn1[1].build((None, None, None, self.config.hidden_size))
+            with tf.name_scope(self.fpn1[3].name):
+                self.fpn1[3].build([None, None, None, self.config.hidden_size])
+        if getattr(self, "fpn2", None) is not None:
+            with tf.name_scope(self.fpn2[0].name):
+                self.fpn2[0].build([None, None, None, self.config.hidden_size])
+
+
+__all__ = [
+    "TFData2VecVisionForImageClassification",
+    "TFData2VecVisionForSemanticSegmentation",
+    "TFData2VecVisionModel",
+    "TFData2VecVisionPreTrainedModel",
+]
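+
+
+# Editor's note: an illustrative post-processing sketch (not part of the upstream module). It assumes
+# `outputs` comes from TFData2VecVisionForSemanticSegmentation as in the docstring example above;
+# the helper name `_demo_postprocess_segmentation` is hypothetical.
+def _demo_postprocess_segmentation(outputs, image_height: int, image_width: int) -> tf.Tensor:
+    # logits have shape (batch_size, num_labels, height, width); move channels last for tf.image.resize
+    logits = tf.transpose(outputs.logits, perm=[0, 2, 3, 1])
+    logits = tf.image.resize(logits, size=(image_height, image_width), method="bilinear")
+    # per-pixel class predictions of shape (batch_size, image_height, image_width)
+    return tf.argmax(logits, axis=-1)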
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modular_data2vec_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modular_data2vec_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..91cb04730e4aec01edff0f35701eb12c3ee5af3d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/data2vec/modular_data2vec_audio.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecText model."""
+
+import math
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import Wav2Vec2BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ..wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2Adapter,
+    Wav2Vec2Encoder,
+    Wav2Vec2FeatureEncoder,
+    Wav2Vec2FeatureProjection,
+    Wav2Vec2ForAudioFrameClassification,
+    Wav2Vec2ForCTC,
+    Wav2Vec2ForSequenceClassification,
+    Wav2Vec2ForXVector,
+    Wav2Vec2Model,
+    Wav2Vec2PreTrainedModel,
+    Wav2Vec2SamePadLayer,
+)
+from .configuration_data2vec_audio import Data2VecAudioConfig
+
+
+class Data2VecAudioConvLayer(GradientCheckpointingLayer):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+
+        hidden_states = hidden_states.transpose(-2, -1)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(-2, -1)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+class Data2VecAudioPadLayer(Wav2Vec2SamePadLayer):
+    pass
+
+
+class Data2VecAudioPositionalConvLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=config.conv_pos_kernel_size,
+            padding=config.conv_pos_kernel_size // 2,
+            groups=config.num_conv_pos_embedding_groups,
+        )
+
+        self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
+        self.activation = ACT2FN[config.feat_extract_activation]
+        # no learnable parameters
+        self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.padding(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+class Data2VecAudioPositionalConvEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.transpose(1, 2)
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+class Data2VecAudioFeatureEncoder(Wav2Vec2FeatureEncoder):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        self.conv_layers = nn.ModuleList(
+            [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
+        )
+        self.gradient_checkpointing = False
+        self._requires_grad = True
+
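+# Editor's note: a small illustrative helper (an assumption, not part of the upstream API) showing how
+# the stacked Data2VecAudioConvLayer modules above shrink the time axis. Each layer is an unpadded
+# Conv1d, so the usual length formula applies; e.g. with the default 7-layer conv stack, 16000 input
+# samples (1 s at 16 kHz) become 49 frames.
+def _demo_feat_extract_output_length(config: Data2VecAudioConfig, input_length: int) -> int:
+    for kernel_size, stride in zip(config.conv_kernel, config.conv_stride):
+        # unpadded 1d convolution: floor((length - kernel_size) / stride) + 1
+        input_length = (input_length - kernel_size) // stride + 1
+    return input_length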
+
+class Data2VecAudioFeatureProjection(Wav2Vec2FeatureProjection):
+    pass
+
+
+class Data2VecAudioEncoder(Wav2Vec2Encoder):
+    pass
+
+
+class Data2VecAudioAdapter(Wav2Vec2Adapter):
+    pass
+
+
+class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
+    config: Data2VecAudioConfig
+    base_model_prefix = "data2vec_audio"
+    main_input_name = "input_values"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, Data2VecAudioFeatureProjection):
+            k = math.sqrt(1 / module.projection.in_features)
+            nn.init.uniform_(module.projection.weight, a=-k, b=k)
+            nn.init.uniform_(module.projection.bias, a=-k, b=k)
+        elif isinstance(module, Data2VecAudioPositionalConvLayer):
+            nn.init.constant_(module.conv.bias, 0)
+        elif isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            if module.bias is not None:
+                module.bias.data.zero_()
+            if module.weight is not None:
+                module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Conv1d):
+            nn.init.kaiming_normal_(module.weight)
+
+            if module.bias is not None:
+                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                nn.init.uniform_(module.bias, a=-k, b=k)
+
+    def _get_adapters(self):
+        raise AttributeError("Not needed for Data2VecAudio")
+
+    def init_adapter_layers(self):
+        raise AttributeError("Not needed for Data2VecAudio")
+
+    def load_adapter(self):
+        raise AttributeError("Not needed for Data2VecAudio")
+
+
+Data2VecAudioBaseModelOutput = Wav2Vec2BaseModelOutput
+
+
+class Data2VecAudioModel(Data2VecAudioPreTrainedModel, Wav2Vec2Model):
+    def __init__(self, config: Data2VecAudioConfig):
+        Data2VecAudioPreTrainedModel.__init__(self, config)
+        self.config = config
+        self.feature_extractor = Data2VecAudioFeatureEncoder(config)
+        self.feature_projection = Data2VecAudioFeatureProjection(config)
+
+        # model only needs masking vector if mask prob is > 0.0
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+        self.encoder = Data2VecAudioEncoder(config)
+
+        self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_extractor(self):
+        raise AttributeError("Not needed for Data2VecAudio")
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        self.feature_extractor._freeze_parameters()
+
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel, Wav2Vec2ForCTC):
+    def __init__(self, config):
+        Data2VecAudioPreTrainedModel.__init__(self, config)
+
+        self.data2vec_audio = Data2VecAudioModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+                "or define `vocab_size` of your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_base_model(self):
+        raise AttributeError("Not needed for Data2VecAudio")
+
+    def tie_weights(self):
+        raise AttributeError("Not needed for Data2VecAudio")
+
+    def forward(self, **super_kwargs):
+        return super().forward(**super_kwargs)
+
+
+class Data2VecAudioForSequenceClassification(Wav2Vec2ForSequenceClassification):
+    pass
+
+
+class Data2VecAudioForAudioFrameClassification(Wav2Vec2ForAudioFrameClassification):
+    pass
+
+
+class Data2VecAudioForXVector(Wav2Vec2ForXVector):
+    pass
+
+
+__all__ = [
+    "Data2VecAudioForAudioFrameClassification",
+    "Data2VecAudioForCTC",
+    "Data2VecAudioForSequenceClassification",
+    "Data2VecAudioForXVector",
+    "Data2VecAudioModel",
+    "Data2VecAudioPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce0f34c778d21a81d03940e7a7951707d898c86
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_dbrx import *
+    from .modeling_dbrx import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/configuration_dbrx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/configuration_dbrx.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b6b2a368cc5f13783ae7b333dc7af78aee9d05
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/configuration_dbrx.py
@@ -0,0 +1,232 @@
+# coding=utf-8
+# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DBRX model configuration"""
+
+from typing import Any, Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DbrxAttentionConfig(PretrainedConfig):
+    """Configuration class for Dbrx Attention.
+
+    This class stores the configuration of a [`DbrxAttention`] layer. It is used to instantiate attention layers
+    according to the specified arguments, defining the layers architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        attn_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the attention layers.
+        clip_qkv (`float`, *optional*):
+            If set, clip the queries, keys, and values in the attention layer to this value.
+        kv_n_heads (`int`, *optional*, defaults to 1): For grouped_query_attention only, allow user to specify number of kv heads.
+        rope_theta (`float`, *optional*, defaults to 10000.0): The base frequency for rope.
+    """
+
+    base_config_key = "attn_config"
+
+    def __init__(
+        self,
+        attn_pdrop: float = 0.0,
+        clip_qkv: Optional[float] = None,
+        kv_n_heads: int = 1,
+        rope_theta: float = 10000.0,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.attn_pdrop = attn_pdrop
+        self.clip_qkv = clip_qkv
+        self.kv_n_heads = kv_n_heads
+        self.rope_theta = rope_theta
+
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]:
+            if k in kwargs:
+                kwargs.pop(k)
+        if len(kwargs) != 0:
+            raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxFFNConfig(PretrainedConfig):
+    """Configuration class for Dbrx FFN.
+
+    This class stores the configuration of a [`DbrxFFN`] layer. It is used to instantiate feedforward layers according to
+    the specified arguments, defining the layers architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying activation function for the FFN.
+            The dict should have a key 'name' with the value being the name of the activation function along with
+            any additional keyword arguments. If `None`, then set to `{"name": "silu"}`.
+        ffn_hidden_size (`int`, *optional*, defaults to 3584): The hidden size of the feedforward network.
+        moe_num_experts (`int`, *optional*, defaults to 4): The number of experts in the mixture of experts layer.
+        moe_top_k (`int`, *optional*, defaults to 1): The number of experts to use in the mixture of experts layer.
+        moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon for the mixture of experts layer.
+        moe_loss_weight (`float`, *optional*, defaults to 0.01): The loss weight for the mixture of experts layer.
+        moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights.
+    """
+
+    base_config_key = "ffn_config"
+
+    def __init__(
+        self,
+        ffn_act_fn: Optional[dict] = None,
+        ffn_hidden_size: int = 3584,
+        moe_num_experts: int = 4,
+        moe_top_k: int = 1,
+        moe_jitter_eps: Optional[float] = None,
+        moe_loss_weight: float = 0.01,
+        moe_normalize_expert_weights: Optional[float] = 1.0,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        if ffn_act_fn is None:
+            ffn_act_fn = {"name": "silu"}
+        self.ffn_act_fn = ffn_act_fn
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.moe_jitter_eps = moe_jitter_eps
+        self.moe_loss_weight = moe_loss_weight
+        self.moe_normalize_expert_weights = moe_normalize_expert_weights
+
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"]:
+            if k in kwargs:
+                kwargs.pop(k)
+        if len(kwargs) != 0:
+            raise ValueError(f"Found unknown {kwargs=}")
+
+
+class DbrxConfig(PretrainedConfig):
+    r"""
+
+    This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
+    specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a different configuration to that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        d_model (`int`, *optional*, defaults to 2048):
+            Dimensionality of the embeddings and hidden states.
+        n_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        n_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        max_seq_len (`int`, *optional*, defaults to 2048):
+            The maximum sequence length of the model.
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`DbrxModel`].
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability applied to the attention output before combining with residual.
+        emb_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the embedding layer.
+        attn_config (`dict`, *optional*):
+            A dictionary used to configure the model's attention module.
+        ffn_config (`dict`, *optional*):
+            A dictionary used to configure the model's FFN module.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary load-balancing loss.
+
+
+    Example:
+    ```python
+    >>> from transformers import DbrxConfig, DbrxModel
+
+    >>> # Initializing a Dbrx configuration
+    >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = DbrxModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "dbrx"
+    sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig}
+    attribute_map = {
+        "num_attention_heads": "n_heads",
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layers",
+        "max_position_embeddings": "max_seq_len",
+    }
+
+    def __init__(
+        self,
+        d_model: int = 2048,
+        n_heads: int = 16,
+        n_layers: int = 24,
+        max_seq_len: int = 2048,
+        vocab_size: int = 32000,
+        resid_pdrop: float = 0.0,
+        emb_pdrop: float = 0.0,
+        attn_config: Optional[DbrxAttentionConfig] = None,
+        ffn_config: Optional[DbrxFFNConfig] = None,
+        use_cache: bool = True,
+        initializer_range: float = 0.02,
+        output_router_logits: bool = False,
+        **kwargs: Any,
+    ):
+        if attn_config is None:
+            self.attn_config = DbrxAttentionConfig()
+        elif isinstance(attn_config, dict):
+            self.attn_config = DbrxAttentionConfig(**attn_config)
+        else:
+            self.attn_config = attn_config
+
+        if ffn_config is None:
+            self.ffn_config = DbrxFFNConfig()
+        elif isinstance(ffn_config, dict):
+            self.ffn_config = DbrxFFNConfig(**ffn_config)
+        else:
+            self.ffn_config = ffn_config
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.output_router_logits = output_router_logits
+        self.num_key_value_heads = self.attn_config.kv_n_heads
+
+        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
+        if tie_word_embeddings:
+            raise ValueError("tie_word_embeddings is not supported for DBRX models.")
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["DbrxConfig"]
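+
+
+# Editor's note: an illustrative sketch (not part of the upstream module) showing how nested
+# `attn_config` / `ffn_config` dicts are promoted to their sub-config classes by DbrxConfig.__init__
+# above. The tiny sizes are arbitrary assumptions chosen for the example.
+def _demo_build_dbrx_config() -> DbrxConfig:
+    config = DbrxConfig(
+        d_model=256,
+        n_heads=8,
+        n_layers=2,
+        vocab_size=128,
+        attn_config={"kv_n_heads": 2, "rope_theta": 500000.0},
+        ffn_config={"ffn_hidden_size": 512, "moe_num_experts": 4, "moe_top_k": 2},
+    )
+    assert isinstance(config.attn_config, DbrxAttentionConfig)
+    assert isinstance(config.ffn_config, DbrxFFNConfig)
+    assert config.num_key_value_heads == 2  # mirrored from attn_config.kv_n_heads
+    return config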
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/modeling_dbrx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/modeling_dbrx.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f3a423213cba3ef11b95a36a3efa5aa01da69c2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dbrx/modeling_dbrx.py
@@ -0,0 +1,1248 @@
+# coding=utf-8
+# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DBRX model."""
+
+import math
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_dbrx import DbrxConfig
+
+
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class DbrxRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # move the expanded buffer onto the input device explicitly (a bare `self.inv_freq.to(...)` call is a no-op)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
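+# Editor's note: a brief illustrative sketch (an assumption, not upstream code) of the tensor shapes
+# flowing through DbrxRotaryEmbedding and apply_rotary_pos_emb above.
+def _demo_apply_rope():
+    batch_size, num_heads, seq_len, head_dim = 2, 4, 6, 8
+    rope = DbrxRotaryEmbedding(dim=head_dim)
+    q = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    k = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
+    cos, sin = rope(q, position_ids)  # each of shape (batch_size, seq_len, head_dim)
+    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes are preserved
+    assert q_rot.shape == q.shape and k_rot.shape == k.shape
+    return q_rot, k_rot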
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
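+# Editor's note: an illustrative check (an assumption, not upstream code) that repeat_kv above matches
+# torch.repeat_interleave along the key/value-head dimension, as its docstring states.
+def _demo_repeat_kv_equivalence() -> torch.Tensor:
+    kv = torch.randn(2, 3, 5, 4)  # (batch, num_key_value_heads, seq_len, head_dim)
+    expanded = repeat_kv(kv, n_rep=2)  # -> (batch, num_key_value_heads * 2, seq_len, head_dim)
+    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=2, dim=1))
+    return expanded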
+
+def load_balancing_loss_func(
+    gate_probabilities: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    attention_mask: Optional[torch.Tensor],
+) -> torch.Tensor:
+    r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_probabilities (`tuple[torch.Tensor]`):
+            Per-layer router probabilities from the `gate`, given as a tuple of model.config.num_hidden_layers
+            tensors of shape [batch_size X sequence_length, num_experts].
+        num_experts (`int`):
+            Number of experts.
+        top_k (`int`):
+            The number of experts each token is routed to.
+        attention_mask (`torch.Tensor`, *optional*):
+            The attention_mask used in forward function
+            shape [batch_size X sequence_length] if not None.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_probabilities is None or not isinstance(gate_probabilities, tuple):
+        return torch.tensor(0.0)
+
+    if isinstance(gate_probabilities, tuple):
+        compute_device = gate_probabilities[0].device
+        routing_weights = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_probabilities], dim=0)
+
+    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+    if attention_mask is None:
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.mean(routing_weights, dim=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = routing_weights.shape[0] // (batch_size * sequence_length)
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+        expert_attention_mask = (
+            attention_mask[None, :, :, None, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+            .reshape(-1, top_k, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the percentage of tokens routed to each experts
+        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+            expert_attention_mask, dim=0
+        )
+
+        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+        router_per_expert_attention_mask = (
+            attention_mask[None, :, :, None]
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
+        )
+
+        # Compute the average probability of routing to these experts
+        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+            router_per_expert_attention_mask, dim=0
+        )
+
+    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
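+# Editor's note: an illustrative call sketch for load_balancing_loss_func above (an assumption, not
+# upstream code). In the model, the tuple holds one router-probability tensor per layer, each of
+# shape (batch_size * sequence_length, num_experts).
+def _demo_load_balancing_loss() -> torch.Tensor:
+    num_experts, top_k = 4, 2
+    per_layer_router_probs = tuple(
+        torch.softmax(torch.randn(3 * 5, num_experts), dim=-1) for _ in range(2)  # 2 layers, 15 tokens each
+    )
+    # returns a scalar auxiliary-loss tensor (already scaled by `num_experts`)
+    return load_balancing_loss_func(per_layer_router_probs, num_experts, top_k, attention_mask=None)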
+
+class DbrxAttention(nn.Module):
+    """Multi-head self attention."""
+
+    def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.d_model
+        self.num_heads = config.n_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.max_seq_len
+        self.block_idx = block_idx
+        if block_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will "
+                + "lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` "
+                + "when creating this class."
+            )
+
+        attn_config = config.attn_config
+        self.attn_pdrop = attn_config.attn_pdrop
+        self.clip_qkv = attn_config.clip_qkv
+        self.num_key_value_heads = attn_config.kv_n_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.rope_theta = attn_config.rope_theta
+        self.is_causal = True
+
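+        # Fused QKV projection: the output below is split into query (hidden_size) and key/value
+        # (num_key_value_heads * head_dim each) chunks. E.g. for the published databricks/dbrx-instruct
+        # config (d_model=6144, n_heads=48, kv_n_heads=8), head_dim is 128 and this maps
+        # 6144 -> 6144 + 2 * 1024 = 8192 (numbers shown for illustration only).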
+        self.Wqkv = nn.Linear(
+            self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.rotary_emb = DbrxRotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query_states, key_states, value_states = qkv_states.split(
+            [
+                self.hidden_size,
+                self.num_key_value_heads * self.head_dim,
+                self.num_key_value_heads * self.head_dim,
+            ],
+            dim=2,
+        )
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                + f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights
+
+
+class DbrxFlashAttention2(DbrxAttention):
+    """Dbrx flash attention module.
+
+    This module inherits from `DbrxAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it calls the
+    public API of Flash Attention.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`. "
+                "Make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+            )
+        logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query_states, key_states, value_states = qkv_states.split(
+            [
+                self.hidden_size,
+                self.num_key_value_heads * self.head_dim,
+                self.num_key_value_heads * self.head_dim,
+            ],
+            dim=2,
+        )
+
+        # Reshape to [batch_size, num_heads, seq_length, head_dim] for RoPE and the KV cache;
+        # Flash Attention itself expects [batch_size, seq_length, num_heads, head_dim], so the
+        # states are transposed back after the cache update below.
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attn_pdrop if self.training else 0.0
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability, so the
+        # input hidden states may be silently upcast to float32. We therefore cast them back to
+        # the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (LlamaRMSNorm handles it correctly.)
+        input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = query_states.dtype
+
+            logger.warning_once(
+                "The input hidden states seem to have been silently cast to float32; this might be "
+                + "related to having upcast embedding or layer norm layers to float32. "
+                + f"We will cast the input back to {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            is_causal=self.is_causal,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.out_proj(attn_output)
+
+        # Flash Attention does not return attention weights, so there is nothing to output here.
+        return attn_output, None
+
+
+class DbrxSdpaAttention(DbrxAttention):
+    """
+    Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `DbrxAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to
+    adapt to the SDPA API.
+    """
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "DbrxModel is using DbrxSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+        query_states, key_states, value_states = qkv_states.split(
+            [
+                self.hidden_size,
+                self.num_key_value_heads * self.head_dim,
+                self.num_key_value_heads * self.head_dim,
+            ],
+            dim=2,
+        )
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = causal_mask is None and q_len > 1
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attn_pdrop if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None
+
+
+DBRX_ATTENTION_CLASSES = {
+    "eager": DbrxAttention,
+    "flash_attention_2": DbrxFlashAttention2,
+    "sdpa": DbrxSdpaAttention,
+}
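+# The attention implementation is selected via `config._attn_implementation`, which is normally set
+# from the `attn_implementation` argument at load time, e.g. (illustrative):
+#
+#     model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct", attn_implementation="sdpa")
+#
+# "eager" selects the plain PyTorch implementation in DbrxAttention above.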
+
+
+class DbrxNormAttentionNorm(nn.Module):
+    def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
+        super().__init__()
+        self.block_idx = block_idx
+        self.resid_pdrop = config.resid_pdrop
+        self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
+        self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
+            config=config,
+            block_idx=block_idx,
+        )
+        self.norm_2 = nn.LayerNorm(config.d_model, bias=False)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        residual_states = hidden_states
+        hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
+
+        hidden_states, attn_weights = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
+        hidden_states = hidden_states + residual_states
+
+        residual_states = hidden_states
+        hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
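+        # `residual_states` carries the post-attention residual stream; `hidden_states` is its
+        # normed version, which feeds the MoE FFN. DbrxBlock adds the FFN output back onto
+        # `residual_states` (see DbrxBlock.forward below).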
+
+        return residual_states, hidden_states, attn_weights
+
+
+class DbrxRouter(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        moe_num_experts: int,
+        moe_top_k: int,
+        moe_jitter_eps: Optional[float],
+        moe_normalize_expert_weights: Optional[float],
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.moe_jitter_eps = moe_jitter_eps
+        self.moe_normalize_expert_weights = moe_normalize_expert_weights
+
+        self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+        if self.training and self.moe_jitter_eps is not None:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(
+                1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
+            )
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
+        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
+
+        top_weights_scale = (
+            torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
+            if self.moe_normalize_expert_weights is not None
+            else 1.0
+        )
+        top_weights = top_weights / top_weights_scale
+
+        weights = weights.to(hidden_states.dtype)
+        top_weights = top_weights.to(hidden_states.dtype)
+        return weights, top_weights, top_experts
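+# Shape sketch for DbrxRouter.forward above (illustrative): for an input of shape
+# (batch_size, seq_len, hidden_size), `weights` has shape (batch_size * seq_len, moe_num_experts),
+# while `top_weights` and `top_experts` have shape (batch_size * seq_len, moe_top_k).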
+
+
+class DbrxExpertGLU(nn.Module):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+
+        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+
+        act_fn_name = ffn_act_fn.get("name", "silu")
+        self.activation_fn = ACT2FN[act_fn_name]
+
+    def forward(
+        self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
+    ) -> torch.Tensor:
+        gate_proj = x.matmul(expert_w1.t())
+        up_proj = x.matmul(expert_v1.t())
+        gate_proj = self.activation_fn(gate_proj)
+        intermediate_states = gate_proj * up_proj
+        down_proj = intermediate_states.matmul(expert_w2)
+        return down_proj
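+# DbrxExpertGLU above computes a gated MLP for a single expert: act(x @ w1.T) * (x @ v1.T) @ w2,
+# where the per-expert slices of `w1`/`v1` have shape (ffn_hidden_size, hidden_size) and the `w2`
+# slice is used un-transposed, mapping ffn_hidden_size back to hidden_size.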
+
+
+class DbrxExperts(nn.Module):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
+        super().__init__()
+        self.moe_num_experts = moe_num_experts
+        self.mlp = DbrxExpertGLU(
+            hidden_size=hidden_size,
+            ffn_hidden_size=ffn_hidden_size,
+            moe_num_experts=moe_num_experts,
+            ffn_act_fn=ffn_act_fn,
+        )
+
+    def forward(
+        self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+    ) -> torch.Tensor:
+        bsz, q_len, hidden_size = x.shape
+        x = x.view(-1, hidden_size)
+        out = torch.zeros_like(x)
+
+        expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
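+        # `expert_mask` has shape (moe_num_experts, moe_top_k, batch_size * q_len); entry [e, k, t]
+        # is 1 when expert `e` is the k-th choice of token `t`.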
+        # Chunk experts at once to avoid storing full parameter multiple times in autograd
+        w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
+            self.moe_num_experts, dim=0
+        )
+        w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
+        v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
+        w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
+        for expert_idx in range(0, self.moe_num_experts):
+            # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
+            # (setting torch._dynamo.config.capture_dynamic_output_shape_ops = True may help, but is untested)
+            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+            if token_idx.shape[0] == 0:
+                continue
+
+            expert_tokens = x[None, token_idx].reshape(-1, hidden_size)
+            expert_out = (
+                self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
+                * top_weights[token_idx, topk_idx, None]
+            )
+
+            out.index_add_(0, token_idx, expert_out)
+
+        out = out.reshape(bsz, q_len, hidden_size)
+        return out
+
+
+class DbrxFFN(nn.Module):
+    def __init__(self, config: DbrxConfig):
+        super().__init__()
+
+        ffn_config = config.ffn_config
+        self.router = DbrxRouter(
+            hidden_size=config.d_model,
+            moe_num_experts=ffn_config.moe_num_experts,
+            moe_top_k=ffn_config.moe_top_k,
+            moe_jitter_eps=ffn_config.moe_jitter_eps,
+            moe_normalize_expert_weights=ffn_config.moe_normalize_expert_weights,
+        )
+
+        self.experts = DbrxExperts(
+            hidden_size=config.d_model,
+            ffn_hidden_size=ffn_config.ffn_hidden_size,
+            moe_num_experts=ffn_config.moe_num_experts,
+            ffn_act_fn=ffn_config.ffn_act_fn,
+        )
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        weights, top_weights, top_experts = self.router(x)
+        out = self.experts(x, weights, top_weights, top_experts)
+        return out, weights
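+# Note: the second value returned by DbrxFFN.forward is the full per-token router probability
+# matrix. DbrxBlock exposes it as `router_logits`, and DbrxForCausalLM passes the collected tuple
+# to `load_balancing_loss_func` when `output_router_logits=True`.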
+
+
+class DbrxBlock(GradientCheckpointingLayer):
+    def __init__(self, config: DbrxConfig, block_idx: int):
+        super().__init__()
+        self.hidden_size = config.d_model
+        self.resid_pdrop = config.resid_pdrop
+        self.block_idx = block_idx
+        self.norm_attn_norm = DbrxNormAttentionNorm(
+            config=config,
+            block_idx=block_idx,
+        )
+        self.ffn = DbrxFFN(config=config)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Any,
+    ) -> Union[
+        tuple[torch.Tensor],
+        tuple[torch.Tensor, Optional[torch.Tensor]],
+        tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
+    ]:
+        """Forward function for DbrxBlock.
+
+        Args:
+            hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
+            attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length)
+                if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
+                if default attention is used.
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
+            output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all
+                attention layers. See `attentions` under returned tensors for more detail.
+            output_router_logits (`bool`, *optional*): Whether or not to return the router logits.
+            use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are
+                returned and can be used to speed up decoding (see `past_key_values`).
+            cache_position (`torch.LongTensor`, *optional*): indices depicting the position of the input
+                sequence tokens in the sequence (used to index the cache).
+        """
+
+        # Norm + Attention + Norm
+        resid_states, hidden_states, self_attn_weights = self.norm_attn_norm(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        # Fully Connected
+        hidden_states, router_logits = self.ffn(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
+        hidden_states = resid_states + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if output_router_logits:
+            outputs += (router_logits,)
+
+        return outputs
+
+
+@auto_docstring
+class DbrxPreTrainedModel(PreTrainedModel):
+    config: DbrxConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DbrxBlock"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+
+    def _init_weights(self, module: nn.Module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.weight.data.fill_(1.0)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, DbrxExpertGLU):
+            module.w1.data.normal_(mean=0.0, std=std)
+            module.v1.data.normal_(mean=0.0, std=std)
+            module.w2.data.normal_(mean=0.0, std=std)
+
+
+@auto_docstring
+class DbrxModel(DbrxPreTrainedModel):
+    """Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer.
+
+    Args:
+        config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+    """
+
+    def __init__(self, config: DbrxConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.emb_pdrop = config.emb_pdrop
+
+        self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
+        self.norm_f = nn.LayerNorm(config.d_model, bias=False)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.wte
+
+    def set_input_embeddings(self, value: nn.Embedding):
+        self.wte = value
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,  # NOOP kwargs, for now
+    ) -> Union[tuple, MoeModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+
+        inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_router_logits = () if output_router_logits else None
+
+        for block in self.blocks:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            block_outputs = block(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+            hidden_states = block_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (block_outputs[1],)
+
+            if output_router_logits:
+                all_router_logits += (block_outputs[-1],)
+
+        hidden_states = self.norm_f(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_router_logits]
+                if v is not None
+            )
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            router_logits=all_router_logits,
+        )
+
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+    def _update_causal_mask(
+        self,
+        attention_mask: Union[torch.Tensor, "BlockMask"],
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool = False,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and (attention_mask == 0.0).any():
+                return attention_mask
+            return None
+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            return attention_mask
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype = input_tensor.dtype
+        sequence_length = input_tensor.shape[1]
+        if using_compilable_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a 4D causal mask here.
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+@auto_docstring(
+    custom_intro="""
+    The DBRX Model transformer for causal language modeling.
+    """
+)
+class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
+    def __init__(self, config: DbrxConfig):
+        super().__init__(config)
+        self.transformer = DbrxModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.moe_loss_weight = config.ffn_config.moe_loss_weight
+        self.num_experts = config.ffn_config.moe_num_experts
+        self.num_experts_per_tok = config.ffn_config.moe_top_k
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.transformer.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Embedding):
+        self.transformer.set_input_embeddings(value)
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder: DbrxModel):
+        self.transformer = decoder
+
+    def get_decoder(self) -> DbrxModel:
+        return self.transformer
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ) -> Union[tuple, MoeCausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, DbrxForCausalLM
+
+        >>> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_router_logits = (
+            output_router_logits if output_router_logits is not None else self.config.output_router_logits
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.transformer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            output_router_logits=output_router_logits,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
+
+        hidden_states = outputs[0]
+        # No upscaling to float was ever done for Dbrx
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        aux_loss = None
+        if output_router_logits:
+            aux_loss = load_balancing_loss_func(
+                outputs.router_logits if return_dict else outputs[-1],
+                self.num_experts,
+                self.num_experts_per_tok,
+                attention_mask,
+            )
+            if labels is not None and loss is not None:
+                loss += self.moe_loss_weight * aux_loss.to(loss.device)  # make sure the aux loss is on the same device
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            if output_router_logits:
+                output = (aux_loss,) + output
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+
+__all__ = ["DbrxForCausalLM", "DbrxModel", "DbrxPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98236a86d7a1e8b4ff16b53fb3ff37befbf1d7ac
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_deit import *
+    from .feature_extraction_deit import *
+    from .image_processing_deit import *
+    from .image_processing_deit_fast import *
+    from .modeling_deit import *
+    from .modeling_tf_deit import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/configuration_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/configuration_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a321ebe293e191e7bbce29b528dfa2f6b00d141
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/configuration_deit.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2021 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DeiT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DeiTModel`]. It is used to instantiate a DeiT
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the DeiT
+    [facebook/deit-base-distilled-patch16-224](https://huggingface.co/facebook/deit-base-distilled-patch16-224)
+    architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        encoder_stride (`int`, *optional*, defaults to 16):
+            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+        pooler_output_size (`int`, *optional*):
+           Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+        pooler_act (`str`, *optional*, defaults to `"tanh"`):
+           The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+           PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+           supported for TensorFlow.
+
+    Example:
+
+    ```python
+    >>> from transformers import DeiTConfig, DeiTModel
+
+    >>> # Initializing a DeiT deit-base-distilled-patch16-224 style configuration
+    >>> configuration = DeiTConfig()
+
+    >>> # Initializing a model (with random weights) from the deit-base-distilled-patch16-224 style configuration
+    >>> model = DeiTModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "deit"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        image_size=224,
+        patch_size=16,
+        num_channels=3,
+        qkv_bias=True,
+        encoder_stride=16,
+        pooler_output_size=None,
+        pooler_act="tanh",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.encoder_stride = encoder_stride
+        self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
+        self.pooler_act = pooler_act
+
+
+class DeiTOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+
+__all__ = ["DeiTConfig", "DeiTOnnxConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/feature_extraction_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/feature_extraction_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d040fd08395f8e921ec688228d7d5faa8963ab81
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/feature_extraction_deit.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for DeiT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_deit import DeiTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DeiTFeatureExtractor(DeiTImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The class DeiTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+            " use DeiTImageProcessor instead.",
+            FutureWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["DeiTFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e2f6c3b5ae5f0f1cf2eb1727d2e3235443b81b9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit.py
@@ -0,0 +1,301 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DeiT."""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import resize, to_channel_dimension_format
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_flat_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+from ...utils.import_utils import requires
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DeiTImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a DeiT image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+            `do_resize` in `preprocess`.
+        size (`dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 256}`):
+            Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
+        resample (`PILImageResampling` filter, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
+            is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
+        crop_size (`dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+            Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+            `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+            parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or a list of floats with one value per image
+            channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or a list of floats with one value
+            per image channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[dict[str, int]] = None,
+        resample: PILImageResampling = PIL.Image.BICUBIC,
+        do_center_crop: bool = True,
+        crop_size: Optional[dict[str, int]] = None,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_rescale: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 256, "width": 256}
+        size = get_size_dict(size)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
+    def resize(
+        self,
+        image: np.ndarray,
+        size: dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image to `(size["height"], size["width"])`.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+        Returns:
+            `np.ndarray`: The resized image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+        output_size = (size["height"], size["width"])
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Optional[dict[str, int]] = None,
+        resample=None,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Optional[dict[str, int]] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after `resize`.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                PILImageResampling filter to use if resizing the image. Only has an effect if `do_resize` is set to
+                `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will
+                be padded with zeros and then cropped.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values to the [0, 1] range.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - `None`: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        resample = resample if resample is not None else self.resample
+        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        size = size if size is not None else self.size
+        size = get_size_dict(size)
+        crop_size = crop_size if crop_size is not None else self.crop_size
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+        images = make_flat_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_center_crop=do_center_crop,
+            crop_size=crop_size,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        all_images = []
+        for image in images:
+            if do_resize:
+                image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+            if do_center_crop:
+                image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
+
+            if do_rescale:
+                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+            if do_normalize:
+                image = self.normalize(
+                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+                )
+
+            all_images.append(image)
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+            for image in all_images
+        ]
+
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["DeiTImageProcessor"]
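+
+
+# Minimal preprocessing sketch, assuming the defaults above (resize to 256x256, center crop to
+# 224x224, rescale by 1/255, ImageNet-standard mean/std). The random uint8 array stands in for a
+# real image and is purely illustrative.
+if __name__ == "__main__":
+    processor = DeiTImageProcessor()
+    dummy_image = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)  # HWC uint8
+    batch = processor(images=dummy_image, return_tensors="np")
+    print(batch["pixel_values"].shape)  # (1, 3, 224, 224), channels-first by default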
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit_fast.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aafeaf50c09455cffeecb3776eb3598c8ceccf2
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/image_processing_deit_fast.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DeiT."""
+
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    PILImageResampling,
+)
+from ...utils import auto_docstring
+
+
+@auto_docstring
+class DeiTImageProcessorFast(BaseImageProcessorFast):
+    # To be checked against the slow image processor
+    # None values left after checking can be removed
+    resample = PILImageResampling.BICUBIC
+    image_mean = IMAGENET_STANDARD_MEAN
+    image_std = IMAGENET_STANDARD_STD
+    size = {"height": 256, "width": 256}
+    crop_size = {"height": 224, "width": 224}
+    do_resize = True
+    do_center_crop = True
+    do_rescale = True
+    do_normalize = True
+
+
+__all__ = ["DeiTImageProcessorFast"]
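+
+
+# Minimal sketch, assuming torch and torchvision are available (the fast processors are torch-backed).
+# It mirrors the slow DeiTImageProcessor defaults declared above.
+if __name__ == "__main__":
+    import torch
+
+    processor = DeiTImageProcessorFast()
+    dummy_image = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)  # CHW uint8
+    batch = processor(images=dummy_image, return_tensors="pt")
+    print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])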
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb03c053f1ee08f011e650daad794821205ff33
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_deit.py
@@ -0,0 +1,791 @@
+# coding=utf-8
+# Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DeiT model."""
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+    MaskedImageModelingOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_deit import DeiTConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DeiTEmbeddings(nn.Module):
+    """
+    Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+        self.patch_embeddings = DeiTPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method interpolates the pre-trained position encodings so the model can be used on higher-resolution
+        images. It is also adapted to support torch.jit tracing and the two class embeddings (CLS and distillation).
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 2
+        num_positions = self.position_embeddings.shape[1] - 2
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_and_dist_pos_embed = self.position_embeddings[:, :2]
+        patch_pos_embed = self.position_embeddings[:, 2:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
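+        # reshape the flat patch position embeddings back to their original 2D grid, interpolate to the
+        # new grid size with bicubic resampling, then flatten again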
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        _, _, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(pixel_values)
+
+        batch_size, seq_length, _ = embeddings.size()
+
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+
+        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+
+        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+        position_embedding = self.position_embeddings
+
+        if interpolate_pos_encoding:
+            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+        embeddings = embeddings + position_embedding
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class DeiTPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+            )
+        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return x
+
+
+# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+    # Normalize the attention scores to probabilities.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    # Mask heads if we want to
+    if attention_mask is not None:
+        attn_weights = attn_weights * attention_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
+class DeiTSelfAttention(nn.Module):
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.config = config
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dropout_prob = config.attention_probs_dropout_prob
+        self.scaling = self.attention_head_size**-0.5
+        self.is_causal = False
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+    def forward(
+        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.shape[0]
+        new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+        key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+        query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
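+        # Dispatch to the attention backend selected in the config (e.g. "sdpa", "flash_attention_2");
+        # fall back to the eager implementation defined above when no other backend is configured.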
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        context_layer, attention_probs = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            is_causal=self.is_causal,
+            scaling=self.scaling,
+            dropout=0.0 if not self.training else self.dropout_prob,
+        )
+
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.reshape(new_context_layer_shape)
+
+        return context_layer, attention_probs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
+class DeiTSelfOutput(nn.Module):
+    """
+    The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
+class DeiTAttention(nn.Module):
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.attention = DeiTSelfAttention(config)
+        self.output = DeiTSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads: set[int]):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self_attn_output, _ = self.attention(hidden_states, head_mask)
+        output = self.output(self_attn_output, hidden_states)
+        return output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
+class DeiTIntermediate(nn.Module):
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
+class DeiTOutput(nn.Module):
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states + input_tensor
+        return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
+class DeiTLayer(GradientCheckpointingLayer):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = DeiTAttention(config)
+        self.intermediate = DeiTIntermediate(config)
+        self.output = DeiTOutput(config)
+        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        hidden_states_norm = self.layernorm_before(hidden_states)
+        attention_output = self.attention(hidden_states_norm, head_mask)
+
+        # first residual connection
+        hidden_states = attention_output + hidden_states
+
+        # in DeiT, layernorm is also applied after self-attention
+        layer_output = self.layernorm_after(hidden_states)
+        layer_output = self.intermediate(layer_output)
+
+        # second residual connection is done here
+        layer_output = self.output(layer_output, hidden_states)
+
+        return layer_output
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
+class DeiTEncoder(nn.Module):
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> BaseModelOutput:
+        for i, layer_module in enumerate(self.layer):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            hidden_states = layer_module(hidden_states, layer_head_mask)
+
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+@auto_docstring
+class DeiTPreTrainedModel(PreTrainedModel):
+    config: DeiTConfig
+    base_model_prefix = "deit"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DeiTLayer"]
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": DeiTLayer,
+        "attentions": DeiTSelfAttention,
+    }
+
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, DeiTEmbeddings):
+            module.cls_token.data.zero_()
+            module.position_embeddings.data.zero_()
+            module.distillation_token.data.zero_()
+            if module.mask_token is not None:
+                module.mask_token.data.zero_()
+
+
+@auto_docstring
+class DeiTModel(DeiTPreTrainedModel):
+    def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
+        r"""
+        add_pooling_layer (`bool`, *optional*, defaults to `True`):
+            Whether to add a pooling layer.
+        use_mask_token (`bool`, *optional*, defaults to `False`):
+            Whether to use a mask token for masked image modeling.
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = DeiTEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = DeiTEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pooler = DeiTPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> DeiTPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See
+        the base class `PreTrainedModel` for more details.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @check_model_inputs(tie_last_hidden_states=False)
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs: BaseModelOutput = self.encoder(embedding_output, head_mask=head_mask)
+        sequence_output = encoder_outputs.last_hidden_state
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+        )
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
+class DeiTPooler(nn.Module):
+    def __init__(self, config: DeiTConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
+        self.activation = ACT2FN[config.pooler_act]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+@auto_docstring(
+    custom_intro="""
+    DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
+
+    
+
+    Note that we provide a script to pre-train this model on custom data in our [examples
+    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
+
+    
+    """
+)
+class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
+    def __init__(self, config: DeiTConfig) -> None:
+        super().__init__(config)
+
+        self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
+
+        self.decoder = nn.Sequential(
+            nn.Conv2d(
+                in_channels=config.hidden_size,
+                out_channels=config.encoder_stride**2 * config.num_channels,
+                kernel_size=1,
+            ),
+            nn.PixelShuffle(config.encoder_stride),
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> MaskedImageModelingOutput:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+        Examples:
+        ```python
+        >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+        >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
+        >>> # create random boolean mask of shape (batch_size, num_patches)
+        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+
+        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+        >>> list(reconstructed_pixel_values.shape)
+        [1, 3, 224, 224]
+        ```"""
+
+        outputs: BaseModelOutputWithPooling = self.deit(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            head_mask=head_mask,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
+        sequence_output = outputs.last_hidden_state
+
+        # Reshape to (batch_size, num_channels, height, width)
+        sequence_output = sequence_output[:, 1:-1]
+        batch_size, sequence_length, num_channels = sequence_output.shape
+        height = width = int(sequence_length**0.5)
+        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+
+        # Reconstruct pixel values
+        reconstructed_pixel_values = self.decoder(sequence_output)
+
+        masked_im_loss = None
+        if bool_masked_pos is not None:
+            size = self.config.image_size // self.config.patch_size
+            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
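+            # expand the (batch, size, size) patch-level mask to pixel resolution so it aligns with the
+            # reconstructed image before computing the masked L1 loss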
+            mask = (
+                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
+                .repeat_interleave(self.config.patch_size, 2)
+                .unsqueeze(1)
+                .contiguous()
+            )
+            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
+            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
+
+        return MaskedImageModelingOutput(
+            loss=masked_im_loss,
+            reconstruction=reconstructed_pixel_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+    the [CLS] token) e.g. for ImageNet.
+    """
+)
+class DeiTForImageClassification(DeiTPreTrainedModel):
+    def __init__(self, config: DeiTConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.deit = DeiTModel(config, add_pooling_layer=False)
+
+        # Classifier head
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ImageClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, DeiTForImageClassification
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
+        >>> # so the head will be randomly initialized, hence the predictions will be random
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+        >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        Predicted class: Polaroid camera, Polaroid Land camera
+        ```"""
+
+        outputs: BaseModelOutputWithPooling = self.deit(
+            pixel_values,
+            head_mask=head_mask,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
+        sequence_output = outputs.last_hidden_state
+
+        logits = self.classifier(sequence_output[:, 0, :])
+        # we don't use the distillation token
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Output type of [`DeiTForImageClassificationWithTeacher`].
+    """
+)
+class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
+    r"""
+    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+        Prediction scores as the average of the cls_logits and distillation logits.
+    cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+        Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+        class token).
+    distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+        Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+        distillation token).
+    """
+
+    logits: Optional[torch.FloatTensor] = None
+    cls_logits: Optional[torch.FloatTensor] = None
+    distillation_logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring(
+    custom_intro="""
+    DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
+    the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+    .. warning::
+
+           This model supports inference only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+           supported.
+    """
+)
+class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
+    def __init__(self, config: DeiTConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.deit = DeiTModel(config, add_pooling_layer=False)
+
+        # Classifier heads
+        self.cls_classifier = (
+            nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+        self.distillation_classifier = (
+            nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> DeiTForImageClassificationWithTeacherOutput:
+        outputs: BaseModelOutputWithPooling = self.deit(
+            pixel_values,
+            head_mask=head_mask,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
+        sequence_output = outputs.last_hidden_state
+
+        cls_logits = self.cls_classifier(sequence_output[:, 0, :])
+        distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
+
+        # during inference, return the average of both classifier predictions
+        logits = (cls_logits + distillation_logits) / 2
+
+        return DeiTForImageClassificationWithTeacherOutput(
+            logits=logits,
+            cls_logits=cls_logits,
+            distillation_logits=distillation_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = [
+    "DeiTForImageClassification",
+    "DeiTForImageClassificationWithTeacher",
+    "DeiTForMaskedImageModeling",
+    "DeiTModel",
+    "DeiTPreTrainedModel",
+]
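+
+
+# Minimal forward-pass sketch with a randomly initialized model, assuming the usual DeiT-base
+# defaults (image_size=224, patch_size=16, hidden_size=768); the sequence length is then
+# 196 patches + [CLS] + distillation token = 198.
+if __name__ == "__main__":
+    config = DeiTConfig()
+    model = DeiTModel(config).eval()
+    pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
+    with torch.no_grad():
+        outputs = model(pixel_values)
+    print(outputs.last_hidden_state.shape)  # torch.Size([1, 198, 768])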
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_tf_deit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_tf_deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c56eee87911edc445641e0bbc14f094e1c5efa7
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deit/modeling_tf_deit.py
@@ -0,0 +1,1232 @@
+# coding=utf-8
+# Copyright 2022 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TensorFlow DeiT model."""
+
+from __future__ import annotations
+
+import collections.abc
+import math
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+    TFBaseModelOutput,
+    TFBaseModelOutputWithPooling,
+    TFImageClassifierOutput,
+    TFMaskedImageModelingOutput,
+)
+from ...modeling_tf_utils import (
+    TFPreTrainedModel,
+    TFSequenceClassificationLoss,
+    get_initializer,
+    keras,
+    keras_serializable,
+    unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_deit import DeiTConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "DeiTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
+_EXPECTED_OUTPUT_SHAPE = [1, 198, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/deit-base-distilled-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+@dataclass
+class TFDeiTForImageClassificationWithTeacherOutput(ModelOutput):
+    """
+    Output type of [`TFDeiTForImageClassificationWithTeacher`].
+
+    Args:
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Prediction scores as the average of the cls_logits and distillation logits.
+        cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+            class token).
+        distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+            distillation token).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
+            the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+            the self-attention heads.
+    """
+
+    logits: tf.Tensor | None = None
+    cls_logits: tf.Tensor | None = None
+    distillation_logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+class TFDeiTEmbeddings(keras.layers.Layer):
+    """
+    Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config: DeiTConfig, use_mask_token: bool = False, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.config = config
+        self.use_mask_token = use_mask_token
+        self.patch_embeddings = TFDeiTPatchEmbeddings(config=config, name="patch_embeddings")
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob, name="dropout")
+
+    def build(self, input_shape=None):
+        self.cls_token = self.add_weight(
+            shape=(1, 1, self.config.hidden_size),
+            initializer=keras.initializers.zeros(),
+            trainable=True,
+            name="cls_token",
+        )
+        self.distillation_token = self.add_weight(
+            shape=(1, 1, self.config.hidden_size),
+            initializer=keras.initializers.zeros(),
+            trainable=True,
+            name="distillation_token",
+        )
+        self.mask_token = None
+        if self.use_mask_token:
+            self.mask_token = self.add_weight(
+                shape=(1, 1, self.config.hidden_size),
+                initializer=keras.initializers.zeros(),
+                trainable=True,
+                name="mask_token",
+            )
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = self.add_weight(
+            shape=(1, num_patches + 2, self.config.hidden_size),
+            initializer=keras.initializers.zeros(),
+            trainable=True,
+            name="position_embeddings",
+        )
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "patch_embeddings", None) is not None:
+            with tf.name_scope(self.patch_embeddings.name):
+                self.patch_embeddings.build(None)
+        if getattr(self, "dropout", None) is not None:
+            with tf.name_scope(self.dropout.name):
+                self.dropout.build(None)
+
+    def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
+        num_patches = embeddings.shape[1] - 2
+        num_positions = self.position_embeddings.shape[1] - 2
+
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        class_pos_embed = self.position_embeddings[:, 0, :]
+        dist_pos_embed = self.position_embeddings[:, 1, :]
+        patch_pos_embed = self.position_embeddings[:, 2:, :]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = tf.reshape(
+            patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        )
+        patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
+        patch_pos_embed = tf.transpose(patch_pos_embed, perm=[0, 2, 3, 1])
+        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
+
+        return tf.concat(
+            [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
+        )
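+    # Illustrative note (not part of the library source): for the checkpoint
+    # "facebook/deit-base-distilled-patch16-224" the pre-trained grid holds
+    # sqrt(196) = 14 x 14 patch positions (196 patches + 2 special tokens). For a
+    # hypothetical 448x448 input, h0 = w0 = 448 // 16 = 28, so the 14x14 position
+    # grid is bicubically resized to 28x28 before being flattened and
+    # re-concatenated with the class and distillation position embeddings.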
+
+    def call(
+        self,
+        pixel_values: tf.Tensor,
+        bool_masked_pos: tf.Tensor | None = None,
+        training: bool = False,
+        interpolate_pos_encoding: bool = False,
+    ) -> tf.Tensor:
+        _, height, width, _ = pixel_values.shape
+
+        embeddings = self.patch_embeddings(pixel_values)
+        batch_size, seq_length, _ = shape_list(embeddings)
+
+        if bool_masked_pos is not None:
+            mask_tokens = tf.tile(self.mask_token, [batch_size, seq_length, 1])
+            # replace the masked visual tokens by mask_tokens
+            mask = tf.expand_dims(bool_masked_pos, axis=-1)
+            mask = tf.cast(mask, dtype=mask_tokens.dtype)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
+        distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
+        embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
+        position_embedding = self.position_embeddings
+        if interpolate_pos_encoding:
+            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
+
+        embeddings = embeddings + position_embedding
+        embeddings = self.dropout(embeddings, training=training)
+        return embeddings
+
+
+class TFDeiTPatchEmbeddings(keras.layers.Layer):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config: DeiTConfig, **kwargs) -> None:
+        super().__init__(**kwargs)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+
+        self.projection = keras.layers.Conv2D(
+            hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
+        )
+
+    def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
+        batch_size, height, width, num_channels = shape_list(pixel_values)
+        if tf.executing_eagerly() and num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+
+        x = self.projection(pixel_values)
+        batch_size, height, width, num_channels = shape_list(x)
+        x = tf.reshape(x, (batch_size, height * width, num_channels))
+        return x
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "projection", None) is not None:
+            with tf.name_scope(self.projection.name):
+                self.projection.build([None, None, None, self.num_channels])
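+
+    # Shape sketch (illustrative, assuming the default 224x224 / patch 16 / hidden 768
+    # configuration): pixel_values (1, 224, 224, 3) -> Conv2D with kernel and stride 16
+    # -> (1, 14, 14, 768) -> reshape -> (1, 196, 768); after the class and distillation
+    # tokens are prepended, the encoder sees (1, 198, 768), matching
+    # _EXPECTED_OUTPUT_SHAPE above.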
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->DeiT
+class TFDeiTSelfAttention(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        if config.hidden_size % config.num_attention_heads != 0:
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
+                f"of attention heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
+
+        self.query = keras.layers.Dense(
+            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        )
+        self.key = keras.layers.Dense(
+            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+        )
+        self.value = keras.layers.Dense(
+            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        )
+        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+        self.config = config
+
+    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
+        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
+        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
+
+        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
+        return tf.transpose(tensor, perm=[0, 2, 1, 3])
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        batch_size = shape_list(hidden_states)[0]
+        mixed_query_layer = self.query(inputs=hidden_states)
+        mixed_key_layer = self.key(inputs=hidden_states)
+        mixed_value_layer = self.value(inputs=hidden_states)
+        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
+        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
+        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        # (batch size, num_heads, seq_len_q, seq_len_k)
+        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+        attention_scores = tf.divide(attention_scores, dk)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = stable_softmax(logits=attention_scores, axis=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(inputs=attention_probs, training=training)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = tf.multiply(attention_probs, head_mask)
+
+        attention_output = tf.matmul(attention_probs, value_layer)
+        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
+
+        # (batch_size, seq_len_q, all_head_size)
+        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
+        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+
+        return outputs
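+
+    # Shape walk-through (illustrative, assuming 12 heads of size 64 as in the base
+    # configuration): hidden_states (B, 198, 768) is projected to Q/K/V, reshaped to
+    # (B, 198, 12, 64) and transposed to (B, 12, 198, 64); QK^T then gives attention
+    # scores of shape (B, 12, 198, 198), which are scaled by sqrt(64), softmaxed and
+    # applied to V before the heads are merged back into (B, 198, 768).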
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "query", None) is not None:
+            with tf.name_scope(self.query.name):
+                self.query.build([None, None, self.config.hidden_size])
+        if getattr(self, "key", None) is not None:
+            with tf.name_scope(self.key.name):
+                self.key.build([None, None, self.config.hidden_size])
+        if getattr(self, "value", None) is not None:
+            with tf.name_scope(self.value.name):
+                self.value.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->DeiT
+class TFDeiTSelfOutput(keras.layers.Layer):
+    """
+    The residual connection is defined in TFDeiTLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->DeiT
+class TFDeiTAttention(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.self_attention = TFDeiTSelfAttention(config, name="attention")
+        self.dense_output = TFDeiTSelfOutput(config, name="output")
+
+    def prune_heads(self, heads):
+        raise NotImplementedError
+
+    def call(
+        self,
+        input_tensor: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        self_outputs = self.self_attention(
+            hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
+        )
+        attention_output = self.dense_output(
+            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
+        )
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self_attention", None) is not None:
+            with tf.name_scope(self.self_attention.name):
+                self.self_attention.build(None)
+        if getattr(self, "dense_output", None) is not None:
+            with tf.name_scope(self.dense_output.name):
+                self.dense_output.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->DeiT
+class TFDeiTIntermediate(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->DeiT
+class TFDeiTOutput(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.dropout(inputs=hidden_states, training=training)
+        hidden_states = hidden_states + input_tensor
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFDeiTLayer(keras.layers.Layer):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.attention = TFDeiTAttention(config, name="attention")
+        self.intermediate = TFDeiTIntermediate(config, name="intermediate")
+        self.deit_output = TFDeiTOutput(config, name="output")
+
+        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
+        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
+        self.config = config
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        attention_outputs = self.attention(
+            # in DeiT, layernorm is applied before self-attention
+            input_tensor=self.layernorm_before(inputs=hidden_states, training=training),
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            training=training,
+        )
+        attention_output = attention_outputs[0]
+
+        # first residual connection
+        hidden_states = attention_output + hidden_states
+
+        # in DeiT, layernorm is also applied after self-attention
+        layer_output = self.layernorm_after(inputs=hidden_states, training=training)
+
+        intermediate_output = self.intermediate(hidden_states=layer_output, training=training)
+
+        # second residual connection is done here
+        layer_output = self.deit_output(
+            hidden_states=intermediate_output, input_tensor=hidden_states, training=training
+        )
+        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
+
+        return outputs
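+
+    # Block structure sketch (illustrative): this is a pre-norm transformer block,
+    # i.e. hidden = hidden + Attention(LayerNorm(hidden)) followed by
+    # hidden = hidden + MLP(LayerNorm(hidden)), where MLP is the intermediate
+    # projection with the configured activation followed by the output projection.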
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "attention", None) is not None:
+            with tf.name_scope(self.attention.name):
+                self.attention.build(None)
+        if getattr(self, "intermediate", None) is not None:
+            with tf.name_scope(self.intermediate.name):
+                self.intermediate.build(None)
+        if getattr(self, "deit_output", None) is not None:
+            with tf.name_scope(self.deit_output.name):
+                self.deit_output.build(None)
+        if getattr(self, "layernorm_before", None) is not None:
+            with tf.name_scope(self.layernorm_before.name):
+                self.layernorm_before.build([None, None, self.config.hidden_size])
+        if getattr(self, "layernorm_after", None) is not None:
+            with tf.name_scope(self.layernorm_after.name):
+                self.layernorm_after.build([None, None, self.config.hidden_size])
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->DeiT
+class TFDeiTEncoder(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.layer = [TFDeiTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        head_mask: tf.Tensor,
+        output_attentions: bool,
+        output_hidden_states: bool,
+        return_dict: bool,
+        training: bool = False,
+    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(
+                hidden_states=hidden_states,
+                head_mask=head_mask[i],
+                output_attentions=output_attentions,
+                training=training,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        # Add last layer
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+
+        return TFBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "layer", None) is not None:
+            for layer in self.layer:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+@keras_serializable
+class TFDeiTMainLayer(keras.layers.Layer):
+    config_class = DeiTConfig
+
+    def __init__(
+        self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.config = config
+
+        self.embeddings = TFDeiTEmbeddings(config, use_mask_token=use_mask_token, name="embeddings")
+        self.encoder = TFDeiTEncoder(config, name="encoder")
+
+        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+        self.pooler = TFDeiTPooler(config, name="pooler") if add_pooling_layer else None
+
+    def get_input_embeddings(self) -> TFDeiTPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        raise NotImplementedError
+
+    def get_head_mask(self, head_mask):
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        return head_mask
+
+    @unpack_inputs
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        bool_masked_pos: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        interpolate_pos_encoding: bool = False,
+        training: bool = False,
+    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor, ...]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # TF 2.0 image layers can't use NCHW format when running on CPU.
+        # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+        pixel_values = tf.transpose(pixel_values, (0, 2, 3, 1))
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask)
+
+        embedding_output = self.embeddings(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            training=training,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output, training=training)
+        pooled_output = self.pooler(sequence_output, training=training) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, None, self.config.hidden_size])
+        if getattr(self, "pooler", None) is not None:
+            with tf.name_scope(self.pooler.name):
+                self.pooler.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTPreTrainedModel with ViT->DeiT all-casing
+class TFDeiTPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DeiTConfig
+    base_model_prefix = "deit"
+    main_input_name = "pixel_values"
+
+
+DEIT_START_DOCSTRING = r"""
+    This model is a TensorFlow
+    [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular
+    TensorFlow Module and refer to the TensorFlow documentation for all matters related to general usage and behavior.
+
+    Parameters:
+        config ([`DeiTConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DEIT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`DeiTImageProcessor.__call__`] for details.
+
+        head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.",
+    DEIT_START_DOCSTRING,
+)
+class TFDeiTModel(TFDeiTPreTrainedModel):
+    def __init__(
+        self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
+    ) -> None:
+        super().__init__(config, **kwargs)
+
+        self.deit = TFDeiTMainLayer(
+            config, add_pooling_layer=add_pooling_layer, use_mask_token=use_mask_token, name="deit"
+        )
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutputWithPooling,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        bool_masked_pos: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        interpolate_pos_encoding: bool = False,
+        training: bool = False,
+    ) -> tuple | TFBaseModelOutputWithPooling:
+        outputs = self.deit(
+            pixel_values=pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
+        )
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "deit", None) is not None:
+            with tf.name_scope(self.deit.name):
+                self.deit.build(None)
+
+
+# Copied from transformers.models.vit.modeling_tf_vit.TFViTPooler with ViT->DeiT
+class TFDeiTPooler(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.pooler_output_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation=config.pooler_act,
+            name="dense",
+        )
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(inputs=first_token_tensor)
+
+        return pooled_output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFDeitPixelShuffle(keras.layers.Layer):
+    """TF layer implementation of torch.nn.PixelShuffle"""
+
+    def __init__(self, upscale_factor: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if not isinstance(upscale_factor, int) or upscale_factor < 2:
+            raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}")
+        self.upscale_factor = upscale_factor
+
+    def call(self, x: tf.Tensor) -> tf.Tensor:
+        hidden_states = x
+        batch_size, _, _, num_input_channels = shape_list(hidden_states)
+        block_size_squared = self.upscale_factor**2
+        output_depth = int(num_input_channels / block_size_squared)
+        # When the number of output channels >= 2, PyTorch's PixelShuffle and
+        # TF's depth_to_space differ in their output as the order of channels selected for combining
+        # is a permutation of the other, cf.
+        # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1
+        permutation = tf.constant(
+            [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
+        )
+        hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
+        hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC")
+        return hidden_states
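+
+    # Worked example (illustrative): with upscale_factor=2 and 12 input channels,
+    # block_size_squared = 4 and output_depth = 3, so the permutation above is
+    # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11] -- the channels are regrouped so that
+    # tf.nn.depth_to_space reproduces the channel ordering of torch.nn.PixelShuffle.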
+
+
+class TFDeitDecoder(keras.layers.Layer):
+    def __init__(self, config: DeiTConfig, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.conv2d = keras.layers.Conv2D(
+            filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, name="0"
+        )
+        self.pixel_shuffle = TFDeitPixelShuffle(config.encoder_stride, name="1")
+        self.config = config
+
+    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+        hidden_states = inputs
+        hidden_states = self.conv2d(hidden_states)
+        hidden_states = self.pixel_shuffle(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "conv2d", None) is not None:
+            with tf.name_scope(self.conv2d.name):
+                self.conv2d.build([None, None, None, self.config.hidden_size])
+        if getattr(self, "pixel_shuffle", None) is not None:
+            with tf.name_scope(self.pixel_shuffle.name):
+                self.pixel_shuffle.build(None)
+
+
+@add_start_docstrings(
+    "DeiT Model with a decoder on top for masked image modeling, as proposed in"
+    " [SimMIM](https://huggingface.co/papers/2111.09886).",
+    DEIT_START_DOCSTRING,
+)
+class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
+    def __init__(self, config: DeiTConfig) -> None:
+        super().__init__(config)
+
+        self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="deit")
+        self.decoder = TFDeitDecoder(config, name="decoder")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        bool_masked_pos: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        interpolate_pos_encoding: bool = False,
+        training: bool = False,
+    ) -> tuple | TFMaskedImageModelingOutput:
+        r"""
+        bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+        Returns:
+
+        Examples:
+        ```python
+        >>> from transformers import AutoImageProcessor, TFDeiTForMaskedImageModeling
+        >>> import tensorflow as tf
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+        >>> model = TFDeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+        >>> pixel_values = image_processor(images=image, return_tensors="tf").pixel_values
+        >>> # create random boolean mask of shape (batch_size, num_patches)
+        >>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
+
+        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+        >>> list(reconstructed_pixel_values.shape)
+        [1, 3, 224, 224]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.deit(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        # Reshape to (batch_size, height, width, num_channels)
+        sequence_output = sequence_output[:, 1:-1]
+        batch_size, sequence_length, num_channels = shape_list(sequence_output)
+        height = width = int(sequence_length**0.5)
+        sequence_output = tf.reshape(sequence_output, (batch_size, height, width, num_channels))
+
+        # Reconstruct pixel values
+        reconstructed_pixel_values = self.decoder(sequence_output, training=training)
+        # TF 2.0 image layers can't use NCHW format when running on CPU, so intermediate layers use NHWC,
+        # including the decoder. We transpose to compute the loss against the pixel values
+        # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
+        reconstructed_pixel_values = tf.transpose(reconstructed_pixel_values, (0, 3, 1, 2))
+
+        masked_im_loss = None
+        if bool_masked_pos is not None:
+            size = self.config.image_size // self.config.patch_size
+            bool_masked_pos = tf.reshape(bool_masked_pos, (-1, size, size))
+            mask = tf.repeat(bool_masked_pos, self.config.patch_size, 1)
+            mask = tf.repeat(mask, self.config.patch_size, 2)
+            mask = tf.expand_dims(mask, 1)
+            mask = tf.cast(mask, tf.float32)
+
+            reconstruction_loss = keras.losses.mean_absolute_error(
+                # Swap axes as metric calculation reduces over the final dimension
+                tf.transpose(pixel_values, (1, 2, 3, 0)),
+                tf.transpose(reconstructed_pixel_values, (1, 2, 3, 0)),
+            )
+            reconstruction_loss = tf.expand_dims(reconstruction_loss, 0)
+            total_loss = tf.reduce_sum(reconstruction_loss * mask)
+            num_masked_pixels = (tf.reduce_sum(mask) + 1e-5) * self.config.num_channels
+            masked_im_loss = total_loss / num_masked_pixels
+            masked_im_loss = tf.reshape(masked_im_loss, (1,))
+
+        if not return_dict:
+            output = (reconstructed_pixel_values,) + outputs[1:]
+            return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+        return TFMaskedImageModelingOutput(
+            loss=masked_im_loss,
+            reconstruction=reconstructed_pixel_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
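+
+    # Mask arithmetic sketch (illustrative, assuming the default 224x224 / patch 16
+    # configuration): bool_masked_pos of shape (1, 196) is reshaped to (1, 14, 14) and
+    # each entry is repeated patch_size times along both spatial axes, then expanded to
+    # a (1, 1, 224, 224) pixel-level mask that weights the L1 reconstruction loss so
+    # that only masked patches contribute.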
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "deit", None) is not None:
+            with tf.name_scope(self.deit.name):
+                self.deit.build(None)
+        if getattr(self, "decoder", None) is not None:
+            with tf.name_scope(self.decoder.name):
+                self.decoder.build(None)
+
+
+@add_start_docstrings(
+    """
+    DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+    the [CLS] token) e.g. for ImageNet.
+    """,
+    DEIT_START_DOCSTRING,
+)
+class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config: DeiTConfig):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit")
+
+        # Classifier head
+        self.classifier = (
+            keras.layers.Dense(config.num_labels, name="classifier")
+            if config.num_labels > 0
+            else keras.layers.Activation("linear", name="classifier")
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        labels: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        interpolate_pos_encoding: bool = False,
+        training: bool = False,
+    ) -> tf.Tensor | TFImageClassifierOutput:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, TFDeiTForImageClassification
+        >>> import tensorflow as tf
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> keras.utils.set_random_seed(3)  # doctest: +IGNORE_RESULT
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> # note: we are loading a TFDeiTForImageClassificationWithTeacher from the hub here,
+        >>> # so the head will be randomly initialized, hence the predictions will be random
+        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
+        >>> model = TFDeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
+
+        >>> inputs = image_processor(images=image, return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits
+        >>> # model predicts one of the 1000 ImageNet classes
+        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+        Predicted class: little blue heron, Egretta caerulea
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.deit(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.classifier(sequence_output[:, 0, :])
+        # we don't use the distillation token
+
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "deit", None) is not None:
+            with tf.name_scope(self.deit.name):
+                self.deit.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
+    the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+    .. warning::
+
+            This model supports inference only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+            supported.
+    """,
+    DEIT_START_DOCSTRING,
+)
+class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
+    def __init__(self, config: DeiTConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.deit = TFDeiTMainLayer(config, add_pooling_layer=False, name="deit")
+
+        # Classifier heads
+        self.cls_classifier = (
+            keras.layers.Dense(config.num_labels, name="cls_classifier")
+            if config.num_labels > 0
+            else keras.layers.Activation("linear", name="cls_classifier")
+        )
+        self.distillation_classifier = (
+            keras.layers.Dense(config.num_labels, name="distillation_classifier")
+            if config.num_labels > 0
+            else keras.layers.Activation("linear", name="distillation_classifier")
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=TFDeiTForImageClassificationWithTeacherOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def call(
+        self,
+        pixel_values: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        interpolate_pos_encoding: bool = False,
+        training: bool = False,
+    ) -> tuple | TFDeiTForImageClassificationWithTeacherOutput:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.deit(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        cls_logits = self.cls_classifier(sequence_output[:, 0, :])
+        distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
+
+        # during inference, return the average of both classifier predictions
+        logits = (cls_logits + distillation_logits) / 2
+
+        if not return_dict:
+            output = (logits, cls_logits, distillation_logits) + outputs[1:]
+            return output
+
+        return TFDeiTForImageClassificationWithTeacherOutput(
+            logits=logits,
+            cls_logits=cls_logits,
+            distillation_logits=distillation_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "deit", None) is not None:
+            with tf.name_scope(self.deit.name):
+                self.deit.build(None)
+        if getattr(self, "cls_classifier", None) is not None:
+            with tf.name_scope(self.cls_classifier.name):
+                self.cls_classifier.build([None, None, self.config.hidden_size])
+        if getattr(self, "distillation_classifier", None) is not None:
+            with tf.name_scope(self.distillation_classifier.name):
+                self.distillation_classifier.build([None, None, self.config.hidden_size])
+
+
+__all__ = [
+    "TFDeiTForImageClassification",
+    "TFDeiTForImageClassificationWithTeacher",
+    "TFDeiTForMaskedImageModeling",
+    "TFDeiTModel",
+    "TFDeiTPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deprecated/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deprecated/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e293c354e1e92a431a601da77d7555f2ecfe29ef
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/deprecated/__init__.py
@@ -0,0 +1,49 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .bort import *
+    from .deta import *
+    from .efficientformer import *
+    from .ernie_m import *
+    from .gptsan_japanese import *
+    from .graphormer import *
+    from .jukebox import *
+    from .mctct import *
+    from .mega import *
+    from .mmbt import *
+    from .nat import *
+    from .nezha import *
+    from .open_llama import *
+    from .qdqbert import *
+    from .realm import *
+    from .retribert import *
+    from .speech_to_text_2 import *
+    from .tapex import *
+    from .trajectory_transformer import *
+    from .transfo_xl import *
+    from .tvlt import *
+    from .van import *
+    from .vit_hybrid import *
+    from .xlm_prophetnet import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
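+
+# Illustrative note (not part of the library source): replacing this module in
+# sys.modules with a _LazyModule means that submodules such as
+# `transformers.models.deprecated.transfo_xl` are only imported when they are first
+# accessed, keeping the import of this package cheap even though many deprecated
+# architectures are listed above.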
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d738fbc087888597da19735271366d4e35ab708c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_dia import *
+    from .feature_extraction_dia import *
+    from .generation_dia import *
+    from .modeling_dia import *
+    from .processing_dia import *
+    from .tokenization_dia import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/configuration_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/configuration_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4dec60b3e4853574e4d528e7b641507a8c0b414
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/configuration_dia.py
@@ -0,0 +1,376 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dia model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaEncoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
+    encoder according to the specified arguments, defining the encoder architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 16):
+            Number of key and value heads for each attention layer in the Transformer encoder.
+        head_dim (`int`, *optional*, defaults to 128):
+            Dimensionality of the attention head.
+        intermediate_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the normalization layers.
+        vocab_size (`int`, *optional*, defaults to 256):
+            Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`DiaModel`].
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"swish"` and `"gelu_new"` are supported.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work with longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+    """
+
+    model_type = "dia_encoder"
+
+    def __init__(
+        self,
+        max_position_embeddings: int = 1024,
+        num_hidden_layers: int = 12,
+        hidden_size: int = 1024,
+        num_attention_heads: int = 16,
+        num_key_value_heads: int = 16,
+        head_dim: int = 128,
+        intermediate_size: int = 4096,
+        norm_eps: float = 1e-5,
+        vocab_size: int = 256,
+        hidden_act: str = "silu",
+        rope_theta: float = 10000.0,
+        rope_scaling: Optional[dict] = None,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        self.max_position_embeddings = max_position_embeddings
+        self.num_hidden_layers = num_hidden_layers
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim
+        self.norm_eps = norm_eps
+        self.vocab_size = vocab_size
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+        self.initializer_range = initializer_range
+        super().__init__(**kwargs)
+
+
+class DiaDecoderConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
+    decoder according to the specified arguments, defining the decoder architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        max_position_embeddings (`int`, *optional*, defaults to 3072):
+            The maximum sequence length that this model might ever be used with.
+        num_hidden_layers (`int`, *optional*, defaults to 18):
+            Number of hidden layers in the Transformer decoder.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the decoder layers.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            Number of key and value heads for each attention layer in the Transformer decoder.
+        head_dim (`int`, *optional*, defaults to 128):
+            Dimensionality of the attention head.
+        cross_num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each cross-attention layer in the Transformer decoder.
+        cross_head_dim (`int`, *optional*, defaults to 128):
+            Dimensionality of the cross-attention head.
+        cross_num_key_value_heads (`int`, *optional*, defaults to 16):
+            Number of key and value heads for each cross-attention layer in the Transformer decoder.
+        cross_hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the cross-attention layers.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the normalization layers.
+        vocab_size (`int`, *optional*, defaults to 1028):
+            Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`DiaModel`].
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
+            `"swish"` and `"gelu_new"` are supported.
+        num_channels (`int`, *optional*, defaults to 9):
+            Number of channels for the Dia decoder.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
+            type and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this
+            value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Indicating that this model is part of an encoder-decoder architecture.
+    """
+
+    model_type = "dia_decoder"
+
+    def __init__(
+        self,
+        max_position_embeddings: int = 3072,
+        num_hidden_layers: int = 18,
+        hidden_size: int = 2048,
+        intermediate_size: int = 8192,
+        num_attention_heads: int = 16,
+        num_key_value_heads: int = 4,
+        head_dim: int = 128,
+        cross_num_attention_heads: int = 16,
+        cross_head_dim: int = 128,
+        cross_num_key_value_heads: int = 16,
+        cross_hidden_size: int = 1024,
+        norm_eps: float = 1e-5,
+        vocab_size: int = 1028,
+        hidden_act: str = "silu",
+        num_channels: int = 9,
+        rope_theta: float = 10000.0,
+        rope_scaling: Optional[dict] = None,
+        initializer_range: float = 0.02,
+        use_cache: bool = True,
+        is_encoder_decoder: bool = True,
+        **kwargs,
+    ):
+        self.max_position_embeddings = max_position_embeddings
+        self.num_hidden_layers = num_hidden_layers
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.cross_num_key_value_heads = cross_num_key_value_heads
+        self.cross_num_attention_heads = cross_num_attention_heads
+        self.cross_head_dim = cross_head_dim
+        self.cross_hidden_size = cross_hidden_size
+        self.norm_eps = norm_eps
+        self.vocab_size = vocab_size
+        self.hidden_act = hidden_act
+        self.num_channels = num_channels
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
+
+
+class DiaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
+    Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the
+    [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        encoder_config (`DiaEncoderConfig`, *optional*):
+            Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
+        decoder_config (`DiaDecoderConfig`, *optional*):
+            Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the normalization layers.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Indicating that this model uses an encoder-decoder architecture.
+        pad_token_id (`int`, *optional*, defaults to 1025):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1024):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 1026):
+            Beginning of stream token id.
+        delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
+            The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+
+    Example:
+
+    ```python
+    >>> from transformers import DiaConfig, DiaModel
+
+    >>> # Initializing a DiaConfig with default values
+    >>> configuration = DiaConfig()
+
+    >>> # Initializing a DiaModel (with random weights) from the configuration
+    >>> model = DiaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "dia"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
+
+    def __init__(
+        self,
+        encoder_config: Optional[DiaEncoderConfig] = None,
+        decoder_config: Optional[DiaDecoderConfig] = None,
+        norm_eps: float = 1e-5,
+        is_encoder_decoder: bool = True,
+        pad_token_id: int = 1025,
+        eos_token_id: int = 1024,
+        bos_token_id: int = 1026,
+        delay_pattern: Optional[list[int]] = None,
+        initializer_range: float = 0.02,
+        use_cache: bool = True,
+        **kwargs,
+    ):
+        if isinstance(encoder_config, dict):
+            encoder_config = DiaEncoderConfig(**encoder_config)
+        if isinstance(decoder_config, dict):
+            decoder_config = DiaDecoderConfig(**decoder_config)
+        self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
+        self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
+        self.norm_eps = norm_eps
+        self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+
+        assert self.decoder_config.num_channels == len(self.delay_pattern), (
+            "Number of channels must match delay pattern length."
+        )
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            bos_token_id=bos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            **kwargs,
+        )
+
+    def get_text_config(self, *args, **kwargs):
+        """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
+        return self.decoder_config
+
+
+__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/feature_extraction_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/feature_extraction_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4376b773b27365932774a35746b2928cf0af707
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/feature_extraction_dia.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Dia"""
+
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaFeatureExtractor(SequenceFeatureExtractor):
+    r"""
+    Constructs a Dia feature extractor.
+
+    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+    most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+    Args:
+        feature_size (`int`, *optional*, defaults to 1):
+            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
+        padding_value (`float`, *optional*, defaults to 0.0):
+            The value that is used for padding.
+        hop_length (`int`, *optional*, defaults to 512):
+            Hop length between successive windows; padded inputs are rounded up to a multiple of this value.
+    """
+
+    model_input_names = ["input_values", "n_quantizers"]
+
+    def __init__(
+        self,
+        feature_size: int = 1,
+        sampling_rate: int = 16000,
+        padding_value: float = 0.0,
+        hop_length: int = 512,
+        **kwargs,
+    ):
+        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+        self.hop_length = hop_length
+
+    def __call__(
+        self,
+        raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
+        truncation: Optional[bool] = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        sampling_rate: Optional[int] = None,
+    ) -> BatchFeature:
+        """
+        Main method to featurize and prepare for the model one or several sequence(s).
+
+        Args:
+            raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
+                values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
+                `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
+                (`feature_size = 2`).
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            truncation (`bool`, *optional*, defaults to `False`):
+                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            sampling_rate (`int`, *optional*):
+                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+                `sampling_rate` at the forward call to prevent silent errors.
+        """
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
+                    f" {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
+        if padding and truncation:
+            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
+        elif padding is None:
+            # by default let's pad the inputs
+            padding = True
+
+        is_batched = bool(
+            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
+        )
+
+        if is_batched:
+            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
+        elif not is_batched and not isinstance(raw_audio, np.ndarray):
+            raw_audio = np.asarray(raw_audio, dtype=np.float32)
+        elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
+            raw_audio = raw_audio.astype(np.float32)
+
+        # always return batch
+        if not is_batched:
+            raw_audio = [np.asarray(raw_audio).T]
+
+        # convert stereo to mono if necessary, unique to Dia
+        for idx, example in enumerate(raw_audio):
+            if self.feature_size == 2 and example.ndim == 2:
+                raw_audio[idx] = np.mean(example, -1)
+
+        # verify inputs are valid
+        for idx, example in enumerate(raw_audio):
+            if example.ndim > 2:
+                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
+            if self.feature_size == 1 and example.ndim != 1:
+                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
+            if self.feature_size == 2 and example.ndim != 1:  # note the conversion before
+                raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
+
+        input_values = BatchFeature({"input_values": raw_audio})
+
+        # temporarily treat it as if we were mono as we also convert stereo to mono
+        original_feature_size = self.feature_size
+        self.feature_size = 1
+
+        # normal padding on batch
+        padded_inputs = self.pad(
+            input_values,
+            max_length=max_length,
+            truncation=truncation,
+            padding=padding,
+            return_attention_mask=True,
+            pad_to_multiple_of=self.hop_length,
+        )
+        padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
+
+        input_values = []
+        for example in padded_inputs.pop("input_values"):
+            if self.feature_size == 1:
+                example = example[..., None]
+            input_values.append(example.T)
+
+        padded_inputs["input_values"] = input_values
+        if return_tensors is not None:
+            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+        # rewrite back to original feature size
+        self.feature_size = original_feature_size
+
+        return padded_inputs
+
+
+__all__ = ["DiaFeatureExtractor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/generation_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/generation_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..c297de7203d4e5b30a189047233976d310179907
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/generation_dia.py
@@ -0,0 +1,463 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+
+from ...generation.logits_process import (
+    DiaClassifierFreeGuidanceLogitsProcessor,
+    DiaEOSChannelFilterLogitsProcessor,
+    DiaEOSDelayPatternLogitsProcessor,
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+)
+from ...generation.stopping_criteria import StoppingCriteriaList
+from ...generation.streamers import BaseStreamer
+from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaGenerationMixin(GenerationMixin):
+    # Indicates CFG which needs preparation to be properly handled by repeats
+    _uses_cfg = None
+
+    def _get_logits_processor(
+        self,
+        generation_config: GenerationConfig,
+        input_ids_seq_length: Optional[int] = None,
+        encoder_input_ids: Optional[torch.LongTensor] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        device: Optional[str] = None,
+        model_kwargs: Optional[dict[str, Any]] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+    ) -> LogitsProcessorList:
+        # These need either a custom order or a custom processor instead
+        # (temporarily disabling them for the super function call)
+        original_guidance_scale = generation_config.guidance_scale
+        original_temperature = generation_config.temperature
+        generation_config.guidance_scale = None
+        generation_config.temperature = None
+
+        # Get base processors and those we can integrate easily
+        custom_processors = LogitsProcessorList()
+
+        if original_temperature is not None and original_temperature != 1.0:
+            custom_processors.append(TemperatureLogitsWarper(original_temperature))
+
+        custom_processors.append(
+            DiaEOSChannelFilterLogitsProcessor(
+                num_channels=len(self.config.delay_pattern),
+                eos_token_id=self.config.eos_token_id,
+            )
+        )
+
+        merged_processors = super()._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=encoder_input_ids,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=custom_processors,
+            device=device,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
+        # Custom processors we need at specific positions
+        if original_guidance_scale is not None and original_guidance_scale != 1:
+            cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
+                guidance_scale=original_guidance_scale,
+                guidance_top_k=generation_config.top_k,
+            )
+            merged_processors.insert(0, cfg_processor)
+
+        merged_processors.append(
+            DiaEOSDelayPatternLogitsProcessor(
+                delay_pattern=self.config.delay_pattern,
+                eos_token_id=self.config.eos_token_id,
+                max_generation_len=generation_config.max_length,
+                device=device,
+            )
+        )
+
+        # Restore the temporarily disabled values
+        generation_config.guidance_scale = original_guidance_scale
+        generation_config.temperature = original_temperature
+
+        return merged_processors
+
+    def _prepare_generation_config(
+        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
+    ) -> tuple[GenerationConfig, dict]:
+        generation_config, model_kwargs = super()._prepare_generation_config(
+            generation_config, use_model_defaults, **kwargs
+        )
+
+        # We allow generation up to max length + max delay pattern
+        # (will revert back to max length after generation)
+        generation_config.max_length += max(self.config.delay_pattern)
+
+        # Internal flag to indicate CFG that needs to prepare unconditioned input
+        self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
+
+        return generation_config, model_kwargs
+
+    def _prepare_model_inputs(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[torch.Tensor] = None,
+        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+    ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
+        inputs, input_name, model_kwargs = super()._prepare_model_inputs(
+            inputs=inputs,
+            bos_token_id=bos_token_id,
+            model_kwargs=model_kwargs,
+        )
+
+        # If CFG is requested we fill in the unconditioned parts
+        if self._uses_cfg:
+            unconditioned_inputs = torch.zeros_like(inputs)
+            inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
+
+            if model_kwargs.get("attention_mask", None) is not None:
+                model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
+
+        return inputs, input_name, model_kwargs
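+
+    # Illustrative sketch (not part of the original file): with CFG enabled and encoder inputs
+    # of shape (batch, seq), the conditioned and unconditioned halves run as one batch of size
+    # 2 * batch:
+    #     inputs         -> torch.cat([inputs, torch.zeros_like(inputs)], dim=0)
+    #     attention_mask -> attention_mask.repeat(2, 1)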
+
+    def _prepare_decoder_input_ids_for_generation(
+        self,
+        batch_size: int,
+        model_input_name: str,
+        model_kwargs: dict[str, torch.Tensor],
+        decoder_start_token_id: torch.Tensor,
+        device: Optional[torch.device] = None,
+    ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
+        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+        # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not, build defaults below
+        decoder_input_ids = decoder_attention_mask = None
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+        if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
+            decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
+
+        # We allow generating without preparation (no proper delay) but discourage it
+        if decoder_input_ids is None or decoder_attention_mask is None:
+            logger.warning_once(
+                "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
+                f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
+                f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
+            )
+
+            num_channels = self.config.decoder_config.num_channels
+            real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
+
+            if decoder_input_ids is None:
+                decoder_input_ids = torch.full(
+                    (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
+                )
+
+            decoder_attention_mask = torch.ones(
+                size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
+            )
+
+        # 2. Determine the valid input and what works as mask within the input
+        delay_mask = decoder_input_ids.long()
+        valid_input_size = (
+            decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
+        )
+        decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
+        decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
+
+        # 3. Overwrite into model kwargs
+        model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+        model_kwargs["decoder_delay_mask"] = delay_mask
+
+        return decoder_input_ids, model_kwargs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        encoder_outputs=None,  # Using this to easily get the batch size
+        decoder_delay_mask=None,
+        **kwargs,
+    ):
+        # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
+        batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
+        input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
+
+        # Base method handles most things except CFG and the delay pattern mask
+        model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
+
+        # Post processing for CFG and overwriting via delay pattern mask
+        # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
+        model_inputs["decoder_input_ids"] = self.apply_delay_mask(
+            input_ids, self.config.pad_token_id, decoder_delay_mask
+        )
+
+        # Depending on cache usage we need to pass all or just one
+        if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
+            model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
+
+        # Be compile friendly
+        model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
+
+        # 2. Apply CFG duplication if needed
+        if self._uses_cfg:
+            for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
+                if model_inputs.get(key, None) is not None:
+                    # double first dimension and keep everything else the same
+                    repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
+                    model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
+
+        return model_inputs
+
+    @staticmethod
+    def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        if delay_mask is None:
+            return input_ids
+
+        mask_len = min(input_ids.shape[1], delay_mask.shape[1])
+        valid_mask = delay_mask[:, :mask_len, :]
+        valid_input = input_ids[:, :mask_len, :]
+
+        # Overwrite the respective parts of the input
+        input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
+
+        return input_ids
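+
+    # Illustrative sketch (not part of the original file), assuming `pad_id = 1025` as in the
+    # default config: positions where the delay mask equals `pad_id` keep the model's generated
+    # tokens, every other position is forced back to the mask value, e.g. for a single channel:
+    #     mask [1026, 1025, 1025], generated [42, 43, 44] -> [1026, 43, 44]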
+
+    def _main_generate_loop(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        use_model_defaults: Optional[bool] = None,
+        custom_generate: Optional[str] = None,
+        **kwargs,
+    ):
+        # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+        generation_mode_kwargs = self._extract_generation_mode_kwargs(
+            custom_generate,
+            kwargs,
+            synced_gpus,
+            assistant_model,
+            streamer,
+        )
+        generation_config, model_kwargs = self._prepare_generation_config(
+            generation_config, use_model_defaults, **kwargs
+        )
+        generation_mode = generation_config.get_generation_mode(assistant_model)
+
+        self._validate_model_kwargs(model_kwargs.copy())
+        self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
+
+        # 2. Set generation parameters if not already defined
+        if synced_gpus is None:
+            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
+
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        # 3. Define model inputs
+        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        batch_size = inputs_tensor.shape[0]
+
+        device = inputs_tensor.device
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
+
+        # 4. Define other model kwargs
+        if "encoder_outputs" not in model_kwargs:
+            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name, generation_config
+            )
+
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
+        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+            batch_size=batch_size,
+            model_input_name=model_input_name,
+            model_kwargs=model_kwargs,
+            decoder_start_token_id=generation_config._decoder_start_token_tensor,
+            device=inputs_tensor.device,
+        )
+
+        if generation_config.token_healing:
+            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
+
+        if streamer is not None:
+            streamer.put(input_ids.cpu())
+
+        # 6. Prepare `max_length` depending on other stopping criteria.
+        # NOTE: incorrect `input_ids.shape[1]` previously
+        input_ids_length = input_ids.shape[-1]
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+        generation_config = self._prepare_generated_length(
+            generation_config=generation_config,
+            has_default_max_length=has_default_max_length,
+            has_default_min_length=has_default_min_length,
+            model_input_name=model_input_name,
+            inputs_tensor=inputs_tensor,
+            input_ids_length=input_ids_length,
+        )
+
+        # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
+        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
+        # dynamically overrides this value as it can need more than the last token logits
+        if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
+            model_kwargs["logits_to_keep"] = 1
+
+        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+        # 7. Prepare the cache.
+        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+        # - different models have a different cache name expected by the model (default = "past_key_values")
+        # - `max_length`, prepared above, is used to determine the maximum cache length
+        max_cache_length = generation_config.max_length - 1
+        if (
+            inputs_tensor.shape[1] != input_ids_length
+            and model_input_name == "inputs_embeds"
+            and not self.config.is_encoder_decoder
+        ):
+            max_cache_length += inputs_tensor.shape[1]
+        self._prepare_cache_for_generation(
+            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
+        )
+
+        # 8. prepare logits processors and stopping criteria
+        prepared_logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_length,
+            encoder_input_ids=inputs_tensor,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+            device=inputs_tensor.device,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+        prepared_stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=generation_mode_kwargs.get("tokenizer"),
+        )
+
+        # Set model_kwargs `use_cache` so we can use it later in forward runs
+        model_kwargs["use_cache"] = generation_config.use_cache
+        # ******************* taken from main generate function up to calling the different methods *******************
+
+        # Flatten to 2D (bsz * channels, seq_len) for the inner generation loop
+        input_ids = input_ids.reshape(-1, input_ids.shape[-1])
+
+        # 10. go into different generation modes
+        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+            # 11. Dia does not support expanding inputs via `num_return_sequences`
+            if generation_config.num_return_sequences > 1:
+                raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
+
+            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            return self._sample(
+                input_ids,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
+                generation_config=generation_config,
+                **generation_mode_kwargs,
+                **model_kwargs,
+            )
+        else:
+            raise ValueError(
+                "Got incompatible mode for generation, should be one of greedy or sampling. "
+                "Ensure that beam search is de-activated by setting `num_beams=1`."
+            )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        use_model_defaults: Optional[bool] = None,
+        custom_generate: Optional[str] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        # We expect the initial input ids to be the complete mask (delayed input)
+        delay_mask = kwargs.get("decoder_input_ids")
+        if delay_mask is not None:
+            delay_mask = delay_mask.clone()
+
+        output = self._main_generate_loop(
+            inputs=inputs,
+            generation_config=generation_config,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            synced_gpus=synced_gpus,
+            assistant_model=assistant_model,
+            streamer=streamer,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            use_model_defaults=use_model_defaults,
+            custom_generate=custom_generate,
+            **kwargs,
+        )
+
+        return_dict_in_generate = not isinstance(output, torch.Tensor)
+
+        if return_dict_in_generate:
+            output_sequences = output.sequences
+        else:
+            output_sequences = output
+
+        # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
+        num_channels = self.config.decoder_config.num_channels
+        bsz = output_sequences.shape[0] // num_channels
+        output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
+
+        # Apply delay mask
+        output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
+
+        if return_dict_in_generate:
+            output.sequences = output_sequences
+        else:
+            output = output_sequences
+
+        return output
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modeling_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modeling_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf662b224aabd884521d7e14b8a167886377f4b5
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modeling_dia.py
@@ -0,0 +1,958 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/dia/modular_dia.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_dia.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    TransformersKwargs,
+    auto_docstring,
+    can_return_tuple,
+    is_torch_flex_attn_available,
+    is_torchdynamo_compiling,
+    logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from .generation_dia import DiaGenerationMixin
+
+
+if is_torch_flex_attn_available():
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class DiaPreTrainedModel(PreTrainedModel):
+    config: DiaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _can_compile_fullgraph = True
+    main_input_name = "input_ids"
+    _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
+
+
+class DiaMultiChannelEmbedding(nn.Module):
+    """In order to efficiently compute the audio embedding from the 9 different channels,
+    we vectorize the embedding process by using a single embedding layer and an offset.
+    Example:
+    - num_embeds = 4
+    - vocab_size = 8
+    - num_channels = 3
+    We would have offsets = [0, 8, 16]
+    If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 7] (codes lie in [0, vocab_size)),
+    then tokens = audio_codes + offsets
+                = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 23]
+    This allows us to use a single embedding layer for all channels.
+    """
+
+    def __init__(self, config: DiaDecoderConfig):
+        super().__init__()
+        self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+        offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size  # (C,)
+        self.register_buffer("offsets", offsets, persistent=False)
+
+    def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
+        tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
+        embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
+        return embeds.sum(dim=2)
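+
+    # Shape sketch (illustrative, not part of the original file): with batch=2, seq=5,
+    # num_channels=9 and hidden_size=2048, `audio_codes` of shape (2, 5, 9) is embedded per
+    # channel to (2, 5, 9, 2048) and summed over the channel dimension to (2, 5, 2048).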
+
+
+class DiaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.activation_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        up_states = self.gate_up_proj(hidden_states)
+
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)
+
+        return self.down_proj(up_states)
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class DiaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        DiaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class DiaRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: DiaConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
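+    # (batch, heads, q_len, head_dim) @ (batch, heads, head_dim, kv_len) -> (batch, heads, q_len, kv_len)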
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class DiaSelfAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = self.config.num_attention_heads
+        self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+        self.scaling = 1
+        self.attention_dropout = 0.0
+        self.is_causal = is_causal
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
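+        # (batch, seq, hidden) -proj-> (batch, seq, heads * head_dim) -view/transpose-> (batch, heads, seq, head_dim)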
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DiaCrossAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.cross_hidden_size = config.cross_hidden_size
+        self.num_heads = self.config.cross_num_attention_heads
+        self.num_key_value_heads = self.config.cross_num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.head_dim = config.cross_head_dim
+        self.scaling = 1
+        self.attention_dropout = 0.0
+        self.is_causal = False
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cross_attention_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+        cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
+        if past_key_values is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
+            value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
+        else:
+            key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+            value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+
+            if past_key_values is not None:
+                # save all states to the cache
+                key_states, value_states = past_key_values.cross_attention_cache.update(
+                    key_states,
+                    value_states,
+                    self.layer_idx,
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                past_key_values.is_updated[self.layer_idx] = True
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DiaEncoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: DiaEncoderConfig, layer_idx: int):
+        super().__init__()
+        self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
+        self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.mlp = DiaMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        residual = hidden_states
+        normed_states = self.pre_sa_norm(hidden_states)
+        self_attn_output, self_attn_weights = self.self_attention(
+            normed_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+        hidden_states = residual + self_attn_output
+
+        residual = hidden_states
+        normed_states = self.post_sa_norm(hidden_states)
+        mlp_out = self.mlp(normed_states)
+        hidden_states = residual + mlp_out
+
+        return hidden_states, self_attn_weights
+
+
+class DiaEncoder(DiaPreTrainedModel):
+    def __init__(self, config: DiaEncoderConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList(
+            [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.rotary_embeddings = DiaRotaryEmbedding(config)
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[BaseModelOutput, tuple]:
+        hidden_states = self.embedding(input_ids)
+
+        # RoPE
+        # Note: We expect right padding and hence always generate
+        # the position ids on the fly to reduce preparation overhead
+        position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
+        position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+        attention_mask = self._update_full_mask(
+            attention_mask,
+            hidden_states,
+        )
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=attention_mask,
+                **kwargs,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        if output_hidden_states:
+            encoder_states += (hidden_states,)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+    def _update_full_mask(
+        self,
+        attention_mask: Union[torch.Tensor, None],
+        inputs_embeds: torch.Tensor,
+    ):
+        if attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(attention_mask, torch.Tensor):
+                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        return attention_mask
+
+
+class DiaDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
+        self.cross_attention = DiaCrossAttention(config, layer_idx)
+        self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.mlp = DiaMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        self_attn_cache = past_key_values
+        if isinstance(self_attn_cache, EncoderDecoderCache):
+            self_attn_cache = self_attn_cache.self_attention_cache
+
+        residual = hidden_states
+        normed_states = self.pre_sa_norm(hidden_states)
+        self_attn_output, self_attn_weights = self.self_attention(
+            normed_states,
+            position_embeddings,
+            attention_mask,
+            # Needs to be an arg in order for in-place operations
+            # (e.g. cache updates under compile) to be carried through properly
+            self_attn_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + self_attn_output
+
+        residual = hidden_states
+        normed_states = self.pre_ca_norm(hidden_states)
+        cross_states, cross_attn_weights = self.cross_attention(
+            normed_states,
+            encoder_hidden_states,
+            attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            **kwargs,
+        )
+        hidden_states = residual + cross_states
+
+        residual = hidden_states
+        normed_states = self.pre_mlp_norm(hidden_states)
+        mlp_out = self.mlp(normed_states)
+        hidden_states = residual + mlp_out
+
+        return hidden_states, self_attn_weights, cross_attn_weights
+
+
+class DiaDecoder(DiaPreTrainedModel):
+    """Transformer Decoder Stack using DenseGeneral."""
+
+    def __init__(self, config: DiaDecoderConfig):
+        super().__init__(config)
+        self.num_channels = config.num_channels
+        self.vocab_size = config.vocab_size
+        self.embeddings = DiaMultiChannelEmbedding(config)
+        self.rotary_embeddings = DiaRotaryEmbedding(config)
+        self.layers = nn.ModuleList(
+            [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
+            The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
+
+            [What are input IDs?](../glossary#input-ids)
+        """
+
+        batch_size, seq_length = input_ids.size()[:-1]
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
+            )
+        if position_ids is None:
+            position_ids = cache_position[None, :]
+
+        # RoPE
+        hidden_states = self.embeddings(input_ids)
+        position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+        if attention_mask is None and not is_torchdynamo_compiling():
+            # required mask seq length can be calculated via length of past cache
+            mask_seq_length = past_key_values_length + seq_length
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
+
+        attention_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+        encoder_attention_mask = self._update_cross_attn_mask(
+            encoder_hidden_states,
+            encoder_attention_mask,
+            hidden_states.shape[:2],
+            hidden_states,
+        )
+
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = layer(
+                hidden_states,
+                position_embeddings,
+                attention_mask,
+                encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+                **kwargs,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns = all_self_attns + (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        hidden_states = self.norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+    def _update_cross_attn_mask(
+        self,
+        encoder_hidden_states: Union[torch.Tensor, None],
+        encoder_attention_mask: Union[torch.Tensor, None],
+        input_shape: torch.Size,
+        inputs_embeds: torch.Tensor,
+    ):
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(encoder_attention_mask, torch.Tensor):
+                    encoder_attention_mask = make_flex_block_causal_mask(
+                        encoder_attention_mask,
+                        query_length=input_shape[-1],
+                        is_causal=False,
+                    )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        return encoder_attention_mask
+
+
+@auto_docstring(
+    custom_intro="""
+    The bare Dia model outputting raw hidden-states without any specific head on top.
+    """
+)
+class DiaModel(DiaPreTrainedModel):
+    def __init__(self, config: DiaConfig):
+        super().__init__(config)
+        self.config = config
+        self.encoder = DiaEncoder(config.encoder_config)
+        self.decoder = DiaDecoder(config.decoder_config)
+        self.post_init()
+
+    def get_encoder(self):
+        return self.encoder
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqModelOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+        or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+            1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+            2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+            Dia to calculate embeddings and subsequent steps more efficiently.
+
+            If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+            `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+            [`DiaProcessor.__call__`] for more details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence token in the position embeddings.
+            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+            [What are position IDs?](../glossary#position-ids)
+        """
+
+        if input_ids is None and encoder_outputs is None:
+            raise ValueError(
+                "You should either provide text ids or the cached text encodings. Neither has been found."
+            )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if self.is_gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                **kwargs,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # By default we initialize the decoder with bos tokens if nothing has been provided
+        bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+        if decoder_input_ids is None:
+            decoder_input_ids = torch.full(
+                size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+            )
+        # Ensure 3D
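+        # A 2D input of shape (batch_size * num_channels, seq_len) is expected to group all channels
+        # of a sample together, so it can be viewed as (batch_size, num_channels, seq_len) and then
+        # transposed to the internal (batch_size, seq_len, num_channels) layout.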
+        if decoder_input_ids.ndim == 2:
+            decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            position_ids=decoder_position_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs[0],
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+    """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+    base_model_prefix = "model"
+
+    def __init__(self, config: DiaConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = DiaModel(config)
+
+        self.num_channels = config.decoder_config.num_channels
+        self.vocab_size = config.decoder_config.vocab_size
+        self.logits_dense = nn.Linear(
+            config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+        )
+        self.loss_type = "ForMaskedLM"
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqLMOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+        or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+            1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+            2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+            Dia to calculate embeddings and subsequent steps more efficiently.
+
+            If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+            `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+            [`DiaProcessor.__call__`] for more details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence token in the position embeddings.
+            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+            [What are position IDs?](../glossary#position-ids)
+        labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in
+            `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+            are ignored (masked).
+        """
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_position_ids=decoder_position_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        last_hidden_state = outputs[0]
+        batch_size = last_hidden_state.shape[0]
+        # 3D <-> 2D makes it necessary to prioritize channel dim
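+        # (batch, seq, hidden) -> (batch, seq, channels, vocab) -> (batch, channels, seq, vocab)
+        # -> (batch * channels, seq, vocab), matching the flattened codebook layout used for the loss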
+        audio_logits = (
+            self.logits_dense(last_hidden_state)
+            .view((batch_size, -1, self.num_channels, self.vocab_size))
+            .transpose(1, 2)
+            .contiguous()
+            .view(batch_size * self.num_channels, -1, self.vocab_size)
+        )
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=audio_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modular_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modular_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99d32a01d9cffab79bd72852d776447e084681b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/modular_dia.py
@@ -0,0 +1,773 @@
+# coding=utf-8
+# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Dia model."""
+
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...cache_utils import DynamicCache, EncoderDecoderCache
+from ...masking_utils import create_causal_mask
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
+from ..llama.modeling_llama import (
+    LlamaAttention,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+    eager_attention_forward,
+)
+from ..phi3.modeling_phi3 import Phi3MLP
+from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
+from .generation_dia import DiaGenerationMixin
+
+
+if is_torch_flex_attn_available():
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+
+@auto_docstring
+class DiaPreTrainedModel(PreTrainedModel):
+    config: DiaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _can_compile_fullgraph = True
+    main_input_name = "input_ids"
+    _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
+
+
+class DiaMultiChannelEmbedding(nn.Module):
+    """In order to efficiently compute the audio embedding from the 9 different channels,
+    we vectorize the embedding process by using a single embedding layer and an offset.
+    Example:
+    - num_embeds = 4
+    - vocab_size = 8
+    - num_channels = 3
+    We would have offsets = [0, 8, 16]
+    If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [4, 5, 6, 7],
+    then tokens = audio_codes + offsets
+                = [0, 1, 2, 3, 9, 11, 12, 15, 20, 21, 22, 23]
+    This allows us to use a single embedding layer for all channels.
+    """
+
+    def __init__(self, config: DiaDecoderConfig):
+        super().__init__()
+        self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+        offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size  # (C,)
+        self.register_buffer("offsets", offsets, persistent=False)
+
+    def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
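+        # Shift each channel's codes into its own slice of the shared vocab via the offsets,
+        # embed all channels with a single lookup, then sum the per-channel embeddings.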
+        tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
+        embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
+        return embeds.sum(dim=2)
+
+
+class DiaMLP(Phi3MLP):
+    pass
+
+
+class DiaRMSNorm(LlamaRMSNorm):
+    pass
+
+
+class DiaRotaryEmbedding(LlamaRotaryEmbedding):
+    pass
+
+
+class DiaSelfAttention(LlamaAttention):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
+        nn.Module.__init__(self)
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = self.config.num_attention_heads
+        self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
+        self.scaling = 1
+        self.attention_dropout = 0.0
+        self.is_causal = is_causal
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+
+class DiaCrossAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.cross_hidden_size = config.cross_hidden_size
+        self.num_heads = self.config.cross_num_attention_heads
+        self.num_key_value_heads = self.config.cross_num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.head_dim = config.cross_head_dim
+        self.scaling = 1
+        self.attention_dropout = 0.0
+        self.is_causal = False
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cross_attention_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+        cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
+        if past_key_values is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
+            value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
+        else:
+            key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+            value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
+
+            if past_key_values is not None:
+                # save all states to the cache
+                key_states, value_states = past_key_values.cross_attention_cache.update(
+                    key_states,
+                    value_states,
+                    self.layer_idx,
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                past_key_values.is_updated[self.layer_idx] = True
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DiaEncoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: DiaEncoderConfig, layer_idx: int):
+        super().__init__()
+        self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
+        self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.mlp = DiaMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        residual = hidden_states
+        normed_states = self.pre_sa_norm(hidden_states)
+        self_attn_output, self_attn_weights = self.self_attention(
+            normed_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+        hidden_states = residual + self_attn_output
+
+        residual = hidden_states
+        normed_states = self.post_sa_norm(hidden_states)
+        mlp_out = self.mlp(normed_states)
+        hidden_states = residual + mlp_out
+
+        return hidden_states, self_attn_weights
+
+
+class DiaEncoder(DiaPreTrainedModel):
+    def __init__(self, config: DiaEncoderConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList(
+            [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.rotary_embeddings = DiaRotaryEmbedding(config)
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Union[BaseModelOutput, tuple]:
+        hidden_states = self.embedding(input_ids)
+
+        # RoPE
+        # Note: We expect right padding and hence always generate
+        # the position ids on the fly to reduce preparation overhead
+        position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
+        position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+        attention_mask = self._update_full_mask(
+            attention_mask,
+            hidden_states,
+        )
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+
+            layer_outputs = encoder_layer(
+                hidden_states,
+                position_embeddings=position_embeddings,
+                attention_mask=attention_mask,
+                **kwargs,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        if output_hidden_states:
+            encoder_states += (hidden_states,)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+    def _update_full_mask(
+        self,
+        attention_mask: Union[torch.Tensor, None],
+        inputs_embeds: torch.Tensor,
+    ):
+        if attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(attention_mask, torch.Tensor):
+                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        return attention_mask
+
+
+class DiaDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: DiaDecoderConfig, layer_idx: int):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
+        self.cross_attention = DiaCrossAttention(config, layer_idx)
+        self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+        self.mlp = DiaMLP(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+        self_attn_cache = past_key_values
+        if isinstance(self_attn_cache, EncoderDecoderCache):
+            self_attn_cache = self_attn_cache.self_attention_cache
+
+        residual = hidden_states
+        normed_states = self.pre_sa_norm(hidden_states)
+        self_attn_output, self_attn_weights = self.self_attention(
+            normed_states,
+            position_embeddings,
+            attention_mask,
+            # Needs to be an arg in order for in-place operations
+            # (e.g. cache updates under compile) to be carried through properly
+            self_attn_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + self_attn_output
+
+        residual = hidden_states
+        normed_states = self.pre_ca_norm(hidden_states)
+        cross_states, cross_attn_weights = self.cross_attention(
+            normed_states,
+            encoder_hidden_states,
+            attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            **kwargs,
+        )
+        hidden_states = residual + cross_states
+
+        residual = hidden_states
+        normed_states = self.pre_mlp_norm(hidden_states)
+        mlp_out = self.mlp(normed_states)
+        hidden_states = residual + mlp_out
+
+        return hidden_states, self_attn_weights, cross_attn_weights
+
+
+class DiaDecoder(DiaPreTrainedModel):
+    """Transformer Decoder Stack using DenseGeneral."""
+
+    def __init__(self, config: DiaDecoderConfig):
+        super().__init__(config)
+        self.num_channels = config.num_channels
+        self.vocab_size = config.vocab_size
+        self.embeddings = DiaMultiChannelEmbedding(config)
+        self.rotary_embeddings = DiaRotaryEmbedding(config)
+        self.layers = nn.ModuleList(
+            [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
+            The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
+
+            [What are input IDs?](../glossary#input-ids)
+        """
+
+        batch_size, seq_length = input_ids.size()[:-1]
+        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
+            )
+        if position_ids is None:
+            position_ids = cache_position[None, :]
+
+        # RoPE
+        hidden_states = self.embeddings(input_ids)
+        position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
+
+        if attention_mask is None and not is_torchdynamo_compiling():
+            # required mask seq length can be calculated via length of past cache
+            mask_seq_length = past_key_values_length + seq_length
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
+
+        attention_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+        encoder_attention_mask = self._update_cross_attn_mask(
+            encoder_hidden_states,
+            encoder_attention_mask,
+            hidden_states.shape[:2],
+            hidden_states,
+        )
+
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = layer(
+                hidden_states,
+                position_embeddings,
+                attention_mask,
+                encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+                **kwargs,
+            )
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns = all_self_attns + (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        hidden_states = self.norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+    def _update_cross_attn_mask(
+        self,
+        encoder_hidden_states: Union[torch.Tensor, None],
+        encoder_attention_mask: Union[torch.Tensor, None],
+        input_shape: torch.Size,
+        inputs_embeds: torch.Tensor,
+    ):
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask,
+                    inputs_embeds.dtype,
+                    tgt_len=input_shape[-1],
+                )
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(encoder_attention_mask, torch.Tensor):
+                    encoder_attention_mask = make_flex_block_causal_mask(
+                        encoder_attention_mask,
+                        query_length=input_shape[-1],
+                        is_causal=False,
+                    )
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        return encoder_attention_mask
+
+
+@auto_docstring(
+    custom_intro="""
+    The bare Dia model outputting raw hidden-states without any specific head on top.
+    """
+)
+class DiaModel(DiaPreTrainedModel):
+    def __init__(self, config: DiaConfig):
+        super().__init__(config)
+        self.config = config
+        self.encoder = DiaEncoder(config.encoder_config)
+        self.decoder = DiaDecoder(config.decoder_config)
+        self.post_init()
+
+    def get_encoder(self):
+        return self.encoder
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqModelOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+        or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+            1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+            2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+            Dia to calculate embeddings and subsequent steps more efficiently.
+
+            If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+            `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+            [`DiaProcessor.__call__`] for more details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence tokens in the position embeddings.
+            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+            [What are position IDs?](../glossary#position-ids)
+        """
+
+        if input_ids is None and encoder_outputs is None:
+            raise ValueError(
+                "You should either provide text ids or the cached text encodings. Neither has been found."
+            )
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if self.is_gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                **kwargs,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
+        elif not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # By default we initialize the decoder with bos tokens if nothing has been provided
+        bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
+        if decoder_input_ids is None:
+            decoder_input_ids = torch.full(
+                size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
+            )
+        # Ensure 3D
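+        # i.e. reshape (batch_size * num_codebooks, seq_len) into (batch_size, seq_len, num_codebooks)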
+        if decoder_input_ids.ndim == 2:
+            decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            position_ids=decoder_position_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs[0],
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
+    """
+)
+class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
+    base_model_prefix = "model"
+
+    def __init__(self, config: DiaConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = DiaModel(config)
+
+        self.num_channels = config.decoder_config.num_channels
+        self.vocab_size = config.decoder_config.vocab_size
+        self.logits_dense = nn.Linear(
+            config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
+        )
+        self.loss_type = "ForMaskedLM"
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    @auto_docstring
+    @can_return_tuple
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_position_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
+        past_key_values: Optional[EncoderDecoderCache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqLMOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
+        or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
+            1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
+            the audio input codebooks are flattened into the batch dimension. This also aligns with the
+            flattened audio logits which are used to calculate the loss.
+
+            2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
+            Dia to calculate embeddings and subsequent steps more efficiently.
+
+            If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
+            `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
+            [`DiaProcessor.__call__`] for more details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
+            Indices of positions of each input sequence tokens in the position embeddings.
+            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
+
+            [What are position IDs?](../glossary#position-ids)
+        labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in
+            `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
+            are ignored (masked).
+        """
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_position_ids=decoder_position_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_outputs,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        last_hidden_state = outputs[0]
+        batch_size = last_hidden_state.shape[0]
+        # 3D <-> 2D makes it necessary to prioritize channel dim
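+        # (bsz, seq, channels * vocab) -> (bsz, seq, channels, vocab) -> (bsz, channels, seq, vocab)
+        # -> (bsz * channels, seq, vocab), aligning the logits with the flattened labels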
+        audio_logits = (
+            self.logits_dense(last_hidden_state)
+            .view((batch_size, -1, self.num_channels, self.vocab_size))
+            .transpose(1, 2)
+            .contiguous()
+            .view(batch_size * self.num_channels, -1, self.vocab_size)
+        )
+
+        loss = None
+        if labels is not None:
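+            # labels are expected in the flattened (batch_size * num_codebooks, ...) layout to match `audio_logits`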
+            loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=audio_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/processing_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/processing_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..402f5152a64bda378ccdf5edd512c86fe643145c
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/processing_dia.py
@@ -0,0 +1,474 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor class for Dia"""
+
+import math
+from pathlib import Path
+from typing import Optional, Union
+
+from ...audio_utils import AudioInput, make_list_of_audio
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...utils import is_soundfile_available, is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+if is_soundfile_available():
+    import soundfile as sf
+
+
+class DiaAudioKwargs(AudioKwargs, total=False):
+    bos_token_id: int
+    eos_token_id: int
+    pad_token_id: int
+    delay_pattern: list[int]
+    generation: bool
+
+
+class DiaProcessorKwargs(ProcessingKwargs, total=False):
+    audio_kwargs: DiaAudioKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": True,
+            "padding_side": "right",
+            "add_special_tokens": False,
+        },
+        "audio_kwargs": {
+            "eos_token_id": 1024,
+            "pad_token_id": 1025,
+            "bos_token_id": 1026,
+            "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
+            "generation": True,
+            "sampling_rate": 44100,
+        },
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
+class DiaProcessor(ProcessorMixin):
+    r"""
+    Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], a [`DiaTokenizer`], and a [`DacModel`] into
+    a single processor. It inherits the audio feature extraction, tokenizer, and audio encode/decode
+    functionalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for
+    more information.
+
+    Args:
+        feature_extractor (`DiaFeatureExtractor`):
+            An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
+        tokenizer (`DiaTokenizer`):
+            An instance of [`DiaTokenizer`]. The tokenizer is a required input.
+        audio_tokenizer (`DacModel`):
+            An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
+    """
+
+    feature_extractor_class = "DiaFeatureExtractor"
+    tokenizer_class = "DiaTokenizer"
+    audio_tokenizer_class = "DacModel"
+
+    def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
+        super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
+
+    def __call__(
+        self,
+        text: Union[str, list[str]],
+        audio: Optional[AudioInput] = None,
+        output_labels: Optional[bool] = False,
+        **kwargs: Unpack[DiaProcessorKwargs],
+    ):
+        """
+        Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
+        forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
+        DacModel's [`~DacModel.encode`]. The `text` argument is forwarded to [`~DiaTokenizer.__call__`]. Please
+        refer to the docstrings of the above methods for more information.
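+
+        A minimal usage sketch (the checkpoint path below is a placeholder):
+
+        ```python
+        >>> from transformers import DiaProcessor
+
+        >>> processor = DiaProcessor.from_pretrained("path/to/dia-checkpoint")
+        >>> batch = processor(text="[S1] Hello there.")
+        >>> # `batch` carries `input_ids`/`attention_mask` for the text encoder and the delayed
+        >>> # `decoder_input_ids`/`decoder_attention_mask` for the audio decoder.
+        ```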
+        """
+        if not is_torch_available():
+            raise ValueError(
+                "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
+                "find it in your environment. You can install torch via `pip install torch`."
+            )
+
+        if text is None:
+            raise ValueError("You need to specify the `text` input to process.")
+
+        output_kwargs = self._merge_kwargs(
+            DiaProcessorKwargs,
+            **kwargs,
+        )
+
+        text_kwargs = output_kwargs["text_kwargs"]
+        audio_kwargs = output_kwargs["audio_kwargs"]
+        common_kwargs = output_kwargs["common_kwargs"]
+
+        return_tensors = common_kwargs.pop("return_tensors", None)
+        if return_tensors != "pt":
+            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
+
+        data = {}
+
+        # Text
+        if isinstance(text, str):
+            text = [text]
+        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+        encodings = self.tokenizer(text, **text_kwargs)
+        data.update(encodings)
+
+        # Audio
+        delay_pattern = audio_kwargs.pop("delay_pattern", None)
+        audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+        audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
+        audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+        generation = audio_kwargs.pop("generation", True)
+        if (
+            audio_bos_token_id is None
+            or audio_eos_token_id is None
+            or audio_pad_token_id is None
+            or delay_pattern is None
+        ):
+            raise ValueError(
+                "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
+                "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
+            )
+
+        if generation and output_labels:
+            raise ValueError(
+                f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
+            )
+
+        batch_size = data["input_ids"].shape[0]
+        num_channels = len(delay_pattern)
+        max_delay = max(delay_pattern)
+
+        # Voice cloning generation / general training
+        if audio is not None:
+            audio = make_list_of_audio(audio)
+            input_audios = self.feature_extractor(audio, **audio_kwargs)
+
+            compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
+            max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
+
+            decoder_input_ids = []
+            decoder_attention_mask = []
+            # TODO: dac with batching is currently broken, but non-batch is working
+            # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
+            for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
+                # get current length with hop length in mind (as if it were sampled as a single audio)
+                base_pad_len = self.feature_extractor.hop_length
+                current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
+
+                encoded_sequence_len = current_audio_len // compression_rate
+                padding_len = max_encoded_sequence_len - encoded_sequence_len
+
+                # compute non-padded forward pass; one extra bos (and eos if training) is added
+                with torch.no_grad():
+                    audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
+                    input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
+
+                if not generation:
+                    input_ids = torch.nn.functional.pad(
+                        input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
+                    )
+
+                # apply padding
+                # +1 for the bos within the real sequence
+                input_ids = torch.nn.functional.pad(
+                    input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
+                )
+                num_valid_inputs = encoded_sequence_len + 1 + max_delay  # sequence + bos + delay
+                num_valid_inputs += 0 if generation else 1  # eos if training
+                attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
+
+                decoder_input_ids.append(input_ids)
+                decoder_attention_mask.append(attention_mask)
+
+            decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
+            decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
+        # TTS generation
+        elif generation:
+            # all bos to start with TTS
+            decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
+
+            # we preemptively add the delay
+            decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
+        else:
+            raise ValueError("If you try to train, you should provide audio data as well.")
+
+        if batch_size != decoder_input_ids.shape[0]:
+            raise ValueError(
+                f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
+                f"audio samples = {decoder_input_ids.shape[0]} instead."
+            )
+
+        # prepare shift indices per delay
+        max_seq_len = decoder_attention_mask.shape[-1]
+        max_audio_len = max_seq_len - max_delay
+        precomputed_idx = self.build_indices(
+            bsz=batch_size,
+            seq_len=max_seq_len,
+            num_channels=num_channels,
+            delay_pattern=delay_pattern,
+            revert=False,
+        )
+
+        # create delay pattern input
+        # the pad token will be used for masking which input is valid for prediction during generation
+        prefill = torch.full(
+            (batch_size, max_seq_len, num_channels),
+            fill_value=audio_pad_token_id,
+            dtype=torch.int,
+        )
+        prefill[:, :max_audio_len] = decoder_input_ids
+
+        delayed_decoder_input_ids = self.apply_audio_delay(
+            audio=prefill,
+            pad_token_id=audio_pad_token_id,
+            bos_token_id=audio_bos_token_id,
+            precomputed_idx=precomputed_idx,
+        )
+
+        data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
+
+        if output_labels:
+            # Base idea is to shift on the sequence dim
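+            # labels[t] = decoder_input_ids[t + 1]; pad/bos positions are masked with -100 and the result is
+            # flattened to (batch_size * num_codebooks, seq_len - 1) to match the flattened audio logits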
+            labels = data["decoder_input_ids"].clone()[:, 1:]
+            labels[labels == audio_pad_token_id] = -100
+            labels[labels == audio_bos_token_id] = -100
+
+            data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
+            data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
+            data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def batch_decode(
+        self,
+        decoder_input_ids: "torch.Tensor",
+        audio_prompt_len: Optional[int] = None,
+        **kwargs: Unpack[DiaProcessorKwargs],
+    ) -> list["torch.Tensor"]:
+        """
+        Decodes a batch of audio codebook sequences into their respective audio waveforms via the
+        `audio_tokenizer`. See [`~DacModel.decode`] for more information.
+
+        Args:
+            decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
+            audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning).
+        """
+        output_kwargs = self._merge_kwargs(
+            DiaProcessorKwargs,
+            **kwargs,
+        )
+        audio_kwargs = output_kwargs["audio_kwargs"]
+
+        delay_pattern = audio_kwargs.pop("delay_pattern", None)
+        audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
+        audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
+        if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
+            raise ValueError(
+                "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
+                "and `delay_pattern`. You may have accidentally overwritten one of those."
+            )
+
+        # either decode the whole audio sequence or only the generated parts
+        if audio_prompt_len is not None:
+            audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
+            start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
+        else:
+            start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
+        # -1 for the eos token
+        end_of_generation_idx = (
+            decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
+        )
+
+        # revert delay
+        bsz, seq_len, num_channels = decoder_input_ids.shape
+        precomputed_idx = self.build_indices(
+            bsz=bsz,
+            seq_len=seq_len,
+            num_channels=num_channels,
+            delay_pattern=delay_pattern,
+            revert=True,
+        )
+
+        output_sequences = self.apply_audio_delay(
+            audio=decoder_input_ids,
+            # We do not care about these values as we cut them out
+            # with `start_of_generation_idx` and `end_of_generation_idx`
+            pad_token_id=-1,
+            bos_token_id=-1,
+            precomputed_idx=precomputed_idx,
+        ).transpose(1, 2)
+
+        # retrieve the correct sequences each
+        audios = []
+        # TODO: see above, dac doesn't work in batches yet
+        with torch.no_grad():
+            for i in range(start_of_generation_idx.shape[0]):
+                output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
+                output_i = output_i.to(self.audio_tokenizer.device)
+                audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
+                audios.append(audio_i)
+
+        return audios
+
+    def decode(
+        self,
+        decoder_input_ids: "torch.Tensor",
+        audio_prompt_len: Optional[int] = None,
+        **kwargs: Unpack[DiaProcessorKwargs],
+    ) -> "torch.Tensor":
+        """
+        Decodes a single sequence of audio codebooks into the respective audio waveform via the
+        `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
+        """
+        if decoder_input_ids.shape[0] != 1:
+            raise ValueError(
+                f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
+            )
+
+        return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
+
+    def get_audio_prompt_len(
+        self,
+        decoder_attention_mask: "torch.Tensor",
+        **kwargs: Unpack[DiaProcessorKwargs],
+    ) -> int:
+        """Utility function to get the audio prompt length."""
+        output_kwargs = self._merge_kwargs(
+            DiaProcessorKwargs,
+            **kwargs,
+        )
+        audio_kwargs = output_kwargs["audio_kwargs"]
+
+        delay_pattern = audio_kwargs.pop("delay_pattern", None)
+        if delay_pattern is None:
+            raise ValueError(
+                "To enable the utility of retrieving the prompt length for Dia, we need the "
+                "`delay_pattern`. You may have accidentally overwritten this."
+            )
+        return decoder_attention_mask.shape[1] - max(delay_pattern)
+
+    # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
+    def save_audio(
+        self,
+        audio: AudioInput,
+        saving_path: Union[str, Path, list[Union[str, Path]]],
+        **kwargs: Unpack[DiaProcessorKwargs],
+    ):
+        # TODO: @eustlb, this should be in AudioProcessor
+        if not is_soundfile_available():
+            raise ImportError("Please install `soundfile` to save audio files.")
+
+        # ensure correct audio input
+        audio = make_list_of_audio(audio)
+
+        # ensure correct saving path
+        if isinstance(saving_path, (str, Path)):
+            saving_path = [saving_path]
+        elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
+            raise ValueError("Invalid input path. Please provide a string, or a list of strings")
+
+        if len(audio) != len(saving_path):
+            raise ValueError("The number of audio and saving paths must be the same")
+
+        output_kwargs = self._merge_kwargs(
+            DiaProcessorKwargs,
+            **kwargs,
+        )
+        audio_kwargs = output_kwargs["audio_kwargs"]
+        sampling_rate = audio_kwargs["sampling_rate"]
+
+        for audio_value, p in zip(audio, saving_path):
+            if isinstance(audio_value, torch.Tensor):
+                audio_value = audio_value.cpu().float().numpy()
+            sf.write(p, audio_value, sampling_rate)
+
+    @staticmethod
+    def build_indices(
+        bsz: int,
+        seq_len: int,
+        num_channels: int,
+        delay_pattern: list[int],
+        revert: bool = False,
+    ) -> tuple["torch.Tensor", "torch.Tensor"]:
+        """
+        Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
+        or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
+        Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
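+
+        For example, with `delay_pattern=[0, 1]` and `revert=False`, channel 1 is shifted one step to the
+        right: its first position maps to a negative index (later replaced by BOS) and every later position
+        reads from the previous timestep of the undelayed sequence.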
+        """
+        delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
+
+        # (0..seq_len-1)
+        sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
+        # + or - delay depending if we delay or revert the delay
+        if not revert:
+            sequence_idx = sequence_idx - delay_array[None, None, :]
+        else:
+            sequence_idx = sequence_idx + delay_array[None, None, :]
+        # if delay goes over the range we clamp back to valid values
+        valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
+
+        batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
+        channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
+
+        all_idx = torch.stack(
+            [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
+            dim=1,
+        ).long()
+
+        return sequence_idx, all_idx
+
+    @staticmethod
+    def apply_audio_delay(
+        audio: "torch.Tensor",
+        pad_token_id: int,
+        bos_token_id: int,
+        precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
+    ) -> "torch.Tensor":
+        """
+        Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
+        inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
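+        E.g. with `delay_pattern=[0, 1]`, applying the delay shifts channel 1 one step to the right (BOS fills
+        its first slot); reverting the delay shifts it back left (PAD fills its last slot).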
+
+        Args:
+            audio: audio tokens of shape [bsz, seq_len, num_channels]
+            pad_token_id: the PAD token
+            bos_token_id: the BOS token
+            precomputed_idx: from `build_indices`
+
+        Returns:
+            final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
+        """
+        # Move everything to the same device
+        device = audio.device
+        sequence_idx, all_idx = precomputed_idx
+        sequence_idx = sequence_idx.to(device)
+        all_idx = all_idx.to(device)
+
+        # Gather per precomputed indices
+        batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
+        gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
+
+        # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
+        mask_bos = sequence_idx < 0
+        mask_pad = sequence_idx >= audio.shape[1]
+        final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
+
+        return final_audio
+
+
+__all__ = ["DiaProcessor"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/tokenization_dia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/tokenization_dia.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e205906ea709ee2c20f25b0bf6f4fa66ab1f4a4
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dia/tokenization_dia.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Dia."""
+
+from typing import Optional
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiaTokenizer(PreTrainedTokenizer):
+    """
+    Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        pad_token (`str`, *optional*, defaults to `""`):
+            The token used for padding, for example when batching sequences of different lengths.
+        unk_token (`str`, *optional*, defaults to `""`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        max_length (`int`, *optional*, defaults to 1024):
+            The maximum length of the sequences when encoding. Sequences longer than this will be truncated.
+        offset (`int`, *optional*, defaults to 0):
+            The offset of the tokenizer.
+    """
+
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        pad_token: Optional[str] = "",
+        unk_token: Optional[str] = "",
+        max_length: Optional[int] = 1024,
+        offset: int = 0,
+        **kwargs,
+    ):
+        # We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well
+        pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token
+        unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token
+
+        self._utf_vocab_size = 2**8  # utf-8 bytes cover 256 values
+        self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")}
+        self.offset = offset
+        super().__init__(
+            unk_token=unk_token,
+            pad_token=pad_token,
+            max_length=max_length,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return self._utf_vocab_size
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text: str) -> list[str]:
+        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
+        tokens = [chr(i) for i in text.encode("utf-8")]
+        return tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+
+        if len(token) != 1:
+            token_id = None
+        else:
+            token_id = ord(token) + self.offset
+
+        return token_id
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        token = chr(index - self.offset)
+        return token
+
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
+        """Converts a sequence of tokens (string) in a single string."""
+        bstring = b""
+        for token in tokens:
+            if token in self.added_tokens_decoder:
+                added_token_obj = self.added_tokens_decoder[token]
+                tok_string = str(added_token_obj).encode("utf-8")
+            elif token in self.added_tokens_encoder:
+                tok_string = token.encode("utf-8")
+            else:
+                tok_string = token.encode("utf-8")  # Assume general string token
+            bstring += tok_string
+        string = bstring.decode("utf-8", errors="ignore")
+        return string
+
+    # No vocab file
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        return ()
+
+
+__all__ = ["DiaTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dialogpt/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dialogpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c162fce0a48bd164bd0e0a615b942ee4805a12aa
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_diffllama import *
+    from .modeling_diffllama import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/configuration_diffllama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/configuration_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..210607271927ab2f3a7aa1ec1e874fb296c32a73
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/configuration_diffllama.py
@@ -0,0 +1,199 @@
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DiffLlama model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
+class DiffLlamaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate a DiffLlama
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
+    will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`DiffLlamaModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 16):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'diffllama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'diffllama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`list[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'diffllama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'diffllama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        lambda_std_dev (`float`, *optional*, defaults to 0.1):
+            The standard deviation for initialization of parameter lambda in attention layer.
+        head_dim (`int`, *optional*):
+            The attention head dimension. If None, it will default to `hidden_size // num_attention_heads`.
+
+    ```python
+    >>> from transformers import DiffLlamaModel, DiffLlamaConfig
+
+    >>> # Initializing a DiffLlama diffllama-7b style configuration
+    >>> configuration = DiffLlamaConfig()
+
+    >>> # Initializing a model from the diffllama-7b style configuration
+    >>> model = DiffLlamaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "diffllama"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=2048,
+        intermediate_size=8192,
+        num_hidden_layers=16,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        lambda_std_dev=0.1,
+        head_dim=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.lambda_std_dev = lambda_std_dev
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["DiffLlamaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modeling_diffllama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modeling_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..094cc375057f71eb51644bf2b49c524613ed22e1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modeling_diffllama.py
@@ -0,0 +1,767 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_diffllama.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ...modeling_layers import (
+    GenericForQuestionAnswering,
+    GenericForSequenceClassification,
+    GenericForTokenClassification,
+    GradientCheckpointingLayer,
+)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
+from ...utils.generic import check_model_inputs
+from .configuration_diffllama import DiffLlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DiffLlamaMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def lambda_init_fn(layer_idx):
+    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
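+# For reference, the schedule above gives 0.8 - 0.6*exp(0) = 0.2 at layer 0, ~0.36 at layer 1, and
+# approaches 0.8 for deep layers, i.e. deeper layers start with a larger baseline weight on the
+# subtracted attention map.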
+
+
+class DiffLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        # the attributes below are not used by this module
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+        self.lambda_init = lambda_init_fn(layer_idx)
+        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, target_len, _ = hidden_states.size()
+        q_len = target_len
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
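+        # Differential attention: concatenate the two halves of the value heads along head_dim,
+        # giving num_heads // 2 value blocks of width 2 * head_dim, then duplicate them across the
+        # head axis so each of the num_heads attention maps below has a matching value block.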
+        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+        value_states = value_states.repeat(1, 2, 1, 1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
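+        # lambda_full = exp(sum(lambda_q1 * lambda_k1)) - exp(sum(lambda_q2 * lambda_k2)) + lambda_init;
+        # the reductions are done in float32 for numerical stability and cast back to the working dtype.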
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DiffLlamaFlashAttention2(DiffLlamaAttention):
+    """
+    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, which needs to call the public flash attention API
+    correctly and deal with padding tokens in case the input contains any.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> tuple[torch.Tensor, None]:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        # we temporarily move to batch_size x num_heads x seq_length x head_dim to apply RoPE and
+        # update the KV cache, then transpose back before calling the flash kernel
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability,
+        # so the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (DiffLlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
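+        # Differential attention with flash kernels: split the value heads into two groups and run a
+        # separate flash-attention pass for each; the two outputs are combined with the lambda
+        # weighting below.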
+        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
+        value_states1 = value_states1.repeat(1, 1, 2, 1)
+        value_states2 = value_states2.repeat(1, 1, 2, 1)
+
+        attn_output1 = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states1,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output2 = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states2,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
+
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
+class DiffLlamaSdpaAttention(DiffLlamaAttention):
+    """
+    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `DiffLlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
+    the SDPA API.
+    """
+
+    # Adapted from DiffLlamaAttention.forward
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
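+        # Same value-head pairing as in the eager path: num_heads // 2 blocks of width 2 * head_dim,
+        # duplicated across the head axis so the two halves of the attention heads can be subtracted.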
+        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+        value_states = value_states.repeat(1, 2, 1, 1)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = causal_mask is None and q_len > 1
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class DiffLlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        DiffLlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+DIFFLLAMA_ATTENTION_CLASSES = {
+    "eager": DiffLlamaAttention,
+    "flash_attention_2": DiffLlamaFlashAttention2,
+    "sdpa": DiffLlamaSdpaAttention,
+}
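+# The decoder layer below picks one of these classes at construction time based on
+# `config._attn_implementation`, e.g. when the model is loaded with
+# `from_pretrained(..., attn_implementation="sdpa")`.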
+
+
+class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+        self.mlp = DiffLlamaMLP(config)
+        self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+@auto_docstring
+class DiffLlamaPreTrainedModel(PreTrainedModel):
+    config: DiffLlamaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DiffLlamaDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = False
+
+    _can_compile_fullgraph = True
+    _supports_attention_backend = False
+    _can_record_outputs = {
+        "hidden_states": DiffLlamaDecoderLayer,
+        "attentions": DiffLlamaAttention,
+    }
+
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, DiffLlamaAttention):
+            module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
+            module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
+            module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
+            module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
+
+
+class DiffLlamaRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: DiffLlamaConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
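+    # forward() returns cos/sin tensors of shape (batch, seq_len, head_dim); apply_rotary_pos_emb
+    # unsqueezes a head dimension before broadcasting them against q and k.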
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+@auto_docstring
+class DiffLlamaModel(DiffLlamaPreTrainedModel):
+    def __init__(self, config: DiffLlamaConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position: torch.Tensor = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = DiffLlamaModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM
+
+        >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b")
+
+        >>> prompt = "What is your favorite condiment?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "What is your favorite condiment?"
+        ```"""
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class DiffLlamaForSequenceClassification(GenericForSequenceClassification, DiffLlamaPreTrainedModel):
+    pass
+
+
+class DiffLlamaForQuestionAnswering(GenericForQuestionAnswering, DiffLlamaPreTrainedModel):
+    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
+
+
+class DiffLlamaForTokenClassification(GenericForTokenClassification, DiffLlamaPreTrainedModel):
+    pass
+
+
+__all__ = [
+    "DiffLlamaPreTrainedModel",
+    "DiffLlamaModel",
+    "DiffLlamaForCausalLM",
+    "DiffLlamaForSequenceClassification",
+    "DiffLlamaForQuestionAnswering",
+    "DiffLlamaForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modular_diffllama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modular_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..253b99edff0d7a557da404fa680ce8403e22ccf1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/diffllama/modular_diffllama.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...cache_utils import Cache, StaticCache
+from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ...modeling_utils import PreTrainedModel
+from ...utils import logging
+from ...utils.deprecation import deprecate_kwarg
+from ..gemma.modeling_gemma import GemmaForCausalLM
+from ..llama.modeling_llama import (
+    LlamaDecoderLayer,
+    LlamaForQuestionAnswering,
+    LlamaForSequenceClassification,
+    LlamaForTokenClassification,
+    LlamaModel,
+    LlamaPreTrainedModel,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+from ..mistral.modeling_mistral import MistralMLP
+from .configuration_diffllama import DiffLlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
+_CONFIG_FOR_DOC = "DiffLlamaConfig"
+
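+# This modular definition composes DiffLlama from existing building blocks (Mistral MLP, Llama
+# decoder layer/model/heads, Gemma causal-LM head); the full modeling_diffllama.py is expanded from
+# it by transformers' modular conversion tooling.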
+
+class DiffLlamaMLP(MistralMLP):
+    pass
+
+
+def lambda_init_fn(layer_idx):
+    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
+
+
+class DiffLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        # the attributes below are not used by this module
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+        self.lambda_init = lambda_init_fn(layer_idx)
+        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, target_len, _ = hidden_states.size()
+        q_len = target_len
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+        value_states = value_states.repeat(1, 2, 1, 1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class DiffLlamaFlashAttention2(DiffLlamaAttention):
+    """
+    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, which needs to call the public flash attention API
+    correctly and deal with padding tokens in case the input contains any.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> tuple[torch.Tensor, None]:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        # we temporarily move to batch_size x num_heads x seq_length x head_dim to apply RoPE and
+        # update the KV cache, then transpose back before calling the flash kernel
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability,
+        # so the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (DiffLlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = (
+                    torch.get_autocast_dtype(device_type)
+                    if hasattr(torch, "get_autocast_dtype")
+                    else torch.get_autocast_gpu_dtype()
+                )
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
+        value_states1 = value_states1.repeat(1, 1, 2, 1)
+        value_states2 = value_states2.repeat(1, 1, 2, 1)
+
+        attn_output1 = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states1,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output2 = _flash_attention_forward(
+            query_states,
+            key_states,
+            value_states2,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            dropout=dropout_rate,
+            sliding_window=getattr(self, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
+
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
+class DiffLlamaSdpaAttention(DiffLlamaAttention):
+    """
+    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `DiffLlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to
+    the SDPA API.
+    """
+
+    # Adapted from DiffLlamaAttention.forward
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+        value_states = value_states.repeat(1, 2, 1, 1)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = causal_mask is None and q_len > 1
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+            query_states.dtype
+        )
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        attn_output = attn_output1 - lambda_full * attn_output2
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output, None
+
+
+DIFFLLAMA_ATTENTION_CLASSES = {
+    "eager": DiffLlamaAttention,
+    "flash_attention_2": DiffLlamaFlashAttention2,
+    "sdpa": DiffLlamaSdpaAttention,
+}
+
+
+class DiffLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+
+        self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+
+class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
+    _supports_flex_attn = False
+    _supports_attention_backend = False
+
+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(self, module)
+        if isinstance(module, DiffLlamaAttention):
+            module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
+            module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
+            module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
+            module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)
+
+
+class DiffLlamaModel(LlamaModel):
+    pass
+
+
+class DiffLlamaForCausalLM(GemmaForCausalLM):
+    pass
+
+
+class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
+    pass
+
+
+class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
+    pass
+
+
+class DiffLlamaForTokenClassification(LlamaForTokenClassification):
+    pass
+
+
+__all__ = [
+    "DiffLlamaPreTrainedModel",
+    "DiffLlamaModel",
+    "DiffLlamaForCausalLM",
+    "DiffLlamaForSequenceClassification",
+    "DiffLlamaForQuestionAnswering",
+    "DiffLlamaForTokenClassification",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64cdbb3c7eb0467f6112225b8c0d9e1f65f9e99
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_dinat import *
+    from .modeling_dinat import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/configuration_dinat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/configuration_dinat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d7fa509c5a3b2f5efc3b936cf1761b4ab0e107
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/configuration_dinat.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dilated Neighborhood Attention Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class DinatConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Dinat
+    [shi-labs/dinat-mini-in1k-224](https://huggingface.co/shi-labs/dinat-mini-in1k-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        patch_size (`int`, *optional*, defaults to 4):
+            The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        embed_dim (`int`, *optional*, defaults to 64):
+            Dimensionality of patch embedding.
+        depths (`list[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
+            Number of layers in each level of the encoder.
+        num_heads (`list[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
+            Number of attention heads in each layer of the Transformer encoder.
+        kernel_size (`int`, *optional*, defaults to 7):
+            Neighborhood Attention kernel size.
+        dilations (`list[list[int]]`, *optional*, defaults to `[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]]`):
+            Dilation value of each NA layer in the Transformer encoder.
+        mlp_ratio (`float`, *optional*, defaults to 3.0):
+            Ratio of MLP hidden dimensionality to embedding dimensionality.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not a learnable bias should be added to the queries, keys and values.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings and encoder.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
+            The initial value for the layer scale. Disabled if <=0.
+        out_features (`list[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`list[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+
+    ```python
+    >>> from transformers import DinatConfig, DinatModel
+
+    >>> # Initializing a Dinat shi-labs/dinat-mini-in1k-224 style configuration
+    >>> configuration = DinatConfig()
+
+    >>> # Initializing a model (with random weights) from the shi-labs/dinat-mini-in1k-224 style configuration
+    >>> model = DinatModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "dinat"
+
+    attribute_map = {
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+    }
+
+    def __init__(
+        self,
+        patch_size=4,
+        num_channels=3,
+        embed_dim=64,
+        depths=[3, 4, 6, 5],
+        num_heads=[2, 4, 8, 16],
+        kernel_size=7,
+        dilations=[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]],
+        mlp_ratio=3.0,
+        qkv_bias=True,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        drop_path_rate=0.1,
+        hidden_act="gelu",
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        layer_scale_init_value=0.0,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.embed_dim = embed_dim
+        self.depths = depths
+        self.num_layers = len(depths)
+        self.num_heads = num_heads
+        self.kernel_size = kernel_size
+        self.dilations = dilations
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.drop_path_rate = drop_path_rate
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        # we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel
+        # this indicates the channel dimension after the last stage of the model
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
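+        # Illustrative note (not part of the original file): with the default embed_dim=64 and
+        # depths=[3, 4, 6, 5] this evaluates to 64 * 2 ** 3 = 512, i.e. the channel width of the
+        # final stage.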
+        self.layer_scale_init_value = layer_scale_init_value
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+__all__ = ["DinatConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/modeling_dinat.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/modeling_dinat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b7ec37b0ea8489e5db87df22c1876fd4548fe86
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinat/modeling_dinat.py
@@ -0,0 +1,855 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Dilated Neighborhood Attention Transformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+    ModelOutput,
+    OptionalDependencyNotAvailable,
+    auto_docstring,
+    is_natten_available,
+    logging,
+    requires_backends,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_dinat import DinatConfig
+
+
+if is_natten_available():
+    from natten.functional import natten2dav, natten2dqkrpb
+else:
+
+    def natten2dqkrpb(*args, **kwargs):
+        raise OptionalDependencyNotAvailable()
+
+    def natten2dav(*args, **kwargs):
+        raise OptionalDependencyNotAvailable()
+
+
+logger = logging.get_logger(__name__)
+
+
+# drop_path and DinatDropPath are from the timm library.
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Dinat encoder's outputs, with potential hidden states and attentions.
+    """
+)
+class DinatEncoderOutput(ModelOutput):
+    r"""
+    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+        shape `(batch_size, hidden_size, height, width)`.
+
+        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+        include the spatial dimensions.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Dinat model's outputs that also contains a pooling of the last hidden states.
+    """
+)
+class DinatModelOutput(ModelOutput):
+    r"""
+    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+        Average pooling of the last layer hidden-state.
+    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+        shape `(batch_size, hidden_size, height, width)`.
+
+        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+        include the spatial dimensions.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Dinat outputs for image classification.
+    """
+)
+class DinatImageClassifierOutput(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Classification (or regression if config.num_labels==1) loss.
+    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+        Classification (or regression if config.num_labels==1) scores (before SoftMax).
+    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+        shape `(batch_size, hidden_size, height, width)`.
+
+        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+        include the spatial dimensions.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class DinatEmbeddings(nn.Module):
+    """
+    Construct the patch and position embeddings.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.patch_embeddings = DinatPatchEmbeddings(config)
+
+        self.norm = nn.LayerNorm(config.embed_dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor]:
+        embeddings = self.patch_embeddings(pixel_values)
+        embeddings = self.norm(embeddings)
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class DinatPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        patch_size = config.patch_size
+        num_channels, hidden_size = config.num_channels, config.embed_dim
+        self.num_channels = num_channels
+
+        if patch_size != 4:
+            # TODO: Support arbitrary patch sizes.
+            raise ValueError("Dinat only supports patch size of 4 at the moment.")
+
+        self.projection = nn.Sequential(
+            nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+        )
+
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
+        _, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+        embeddings = self.projection(pixel_values)
+        embeddings = embeddings.permute(0, 2, 3, 1)
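+        # Illustrative note (assumption, not from the original file): the two stride-2 convolutions
+        # above reduce the spatial resolution by a factor of 4 in each dimension, so a
+        # (batch, 3, 224, 224) input yields (batch, 56, 56, embed_dim) patch embeddings,
+        # consistent with the fixed patch size of 4.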
+
+        return embeddings
+
+
+class DinatDownsampler(nn.Module):
+    """
+    Convolutional Downsampling Layer.
+
+    Args:
+        dim (`int`):
+            Number of input channels.
+        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+            Normalization layer class.
+    """
+
+    def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+        self.norm = norm_layer(2 * dim)
+
+    def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
+        input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+        input_feature = self.norm(input_feature)
+        return input_feature
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
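+# Illustrative note (assumption, not from the original file): with drop_prob=0.1 each sample's
+# residual branch is zeroed with probability 0.1 during training, and surviving branches are
+# rescaled by 1 / 0.9 so the expected value of the output matches the undropped branch.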
+
+
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat
+class DinatDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return f"p={self.drop_prob}"
+
+
+class NeighborhoodAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, kernel_size, dilation):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError(
+                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+            )
+
+        self.num_attention_heads = num_heads
+        self.attention_head_size = int(dim / num_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+
+        # rpb holds the learnable relative positional biases; the same concept is used in Swin.
+        self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))
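+        # Illustrative note (assumption, not from the original file): with the default kernel_size of 7
+        # this bias table is 13 x 13 per head, one entry for every relative offset inside the window.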
+
+        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.Tensor]:
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = (
+            self.query(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        key_layer = (
+            self.key(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+        value_layer = (
+            self.value(hidden_states)
+            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
+            .transpose(1, 2)
+        )
+
+        # Apply the scale factor before computing attention weights. It's usually more efficient because
+        # attention weights are typically a bigger tensor compared to query.
+        # It gives identical results because scalars commute with matrix multiplication.
+        query_layer = query_layer / math.sqrt(self.attention_head_size)
+
+        # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases.
+        attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
+        context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+class NeighborhoodAttentionOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, dim)
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+class NeighborhoodAttentionModule(nn.Module):
+    def __init__(self, config, dim, num_heads, kernel_size, dilation):
+        super().__init__()
+        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation)
+        self.output = NeighborhoodAttentionOutput(config, dim)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.Tensor]:
+        self_outputs = self.self(hidden_states, output_attentions)
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class DinatIntermediate(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class DinatOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class DinatLayer(nn.Module):
+    def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.kernel_size = config.kernel_size
+        self.dilation = dilation
+        self.window_size = self.kernel_size * self.dilation
+        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+        self.attention = NeighborhoodAttentionModule(
+            config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation
+        )
+        self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+        self.intermediate = DinatIntermediate(config, dim)
+        self.output = DinatOutput(config, dim)
+        self.layer_scale_parameters = (
+            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
+            if config.layer_scale_init_value > 0
+            else None
+        )
+
+    def maybe_pad(self, hidden_states, height, width):
+        window_size = self.window_size
+        pad_values = (0, 0, 0, 0, 0, 0)
+        if height < window_size or width < window_size:
+            pad_l = pad_t = 0
+            pad_r = max(0, window_size - width)
+            pad_b = max(0, window_size - height)
+            pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
+            hidden_states = nn.functional.pad(hidden_states, pad_values)
+        return hidden_states, pad_values
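+    # Illustrative note (assumption, not from the original file): with kernel_size=7 and dilation=2
+    # the effective window is 14, so a 10 x 10 feature map is padded on the right and bottom to
+    # 14 x 14 before attention and cropped back to 10 x 10 afterwards in forward().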
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size, height, width, channels = hidden_states.size()
+        shortcut = hidden_states
+
+        hidden_states = self.layernorm_before(hidden_states)
+        # pad hidden_states if they are smaller than kernel size x dilation
+        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+        _, height_pad, width_pad, _ = hidden_states.shape
+
+        attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)
+
+        attention_output = attention_outputs[0]
+
+        was_padded = pad_values[3] > 0 or pad_values[5] > 0
+        if was_padded:
+            attention_output = attention_output[:, :height, :width, :].contiguous()
+
+        if self.layer_scale_parameters is not None:
+            attention_output = self.layer_scale_parameters[0] * attention_output
+
+        hidden_states = shortcut + self.drop_path(attention_output)
+
+        layer_output = self.layernorm_after(hidden_states)
+        layer_output = self.output(self.intermediate(layer_output))
+
+        if self.layer_scale_parameters is not None:
+            layer_output = self.layer_scale_parameters[1] * layer_output
+
+        layer_output = hidden_states + self.drop_path(layer_output)
+
+        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+        return layer_outputs
+
+
+class DinatStage(nn.Module):
+    def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample):
+        super().__init__()
+        self.config = config
+        self.dim = dim
+        self.layers = nn.ModuleList(
+            [
+                DinatLayer(
+                    config=config,
+                    dim=dim,
+                    num_heads=num_heads,
+                    dilation=dilations[i],
+                    drop_path_rate=drop_path_rate[i],
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
+        else:
+            self.downsample = None
+
+        self.pointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.Tensor]:
+        _, height, width, _ = hidden_states.size()
+        for i, layer_module in enumerate(self.layers):
+            layer_outputs = layer_module(hidden_states, output_attentions)
+            hidden_states = layer_outputs[0]
+
+        hidden_states_before_downsampling = hidden_states
+        if self.downsample is not None:
+            hidden_states = self.downsample(hidden_states_before_downsampling)
+
+        stage_outputs = (hidden_states, hidden_states_before_downsampling)
+
+        if output_attentions:
+            stage_outputs += layer_outputs[1:]
+        return stage_outputs
+
+
+class DinatEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.num_levels = len(config.depths)
+        self.config = config
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
+        self.levels = nn.ModuleList(
+            [
+                DinatStage(
+                    config=config,
+                    dim=int(config.embed_dim * 2**i_layer),
+                    depth=config.depths[i_layer],
+                    num_heads=config.num_heads[i_layer],
+                    dilations=config.dilations[i_layer],
+                    drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+                    downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None,
+                )
+                for i_layer in range(self.num_levels)
+            ]
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[tuple, DinatEncoderOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_reshaped_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if output_hidden_states:
+            # rearrange b h w c -> b c h w
+            reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
+            all_hidden_states += (hidden_states,)
+            all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+        for i, layer_module in enumerate(self.levels):
+            layer_outputs = layer_module(hidden_states, output_attentions)
+
+            hidden_states = layer_outputs[0]
+            hidden_states_before_downsampling = layer_outputs[1]
+
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                # rearrange b h w c -> b c h w
+                reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
+                # rearrange b h w c -> b c h w
+                reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+            if output_attentions:
+                all_self_attentions += layer_outputs[2:]
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+        return DinatEncoderOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            reshaped_hidden_states=all_reshaped_hidden_states,
+        )
+
+
+@auto_docstring
+class DinatPreTrainedModel(PreTrainedModel):
+    config: DinatConfig
+    base_model_prefix = "dinat"
+    main_input_name = "pixel_values"
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class DinatModel(DinatPreTrainedModel):
+    def __init__(self, config, add_pooling_layer=True):
+        r"""
+        add_pooling_layer (`bool`, *optional*, defaults to `True`):
+            Whether to add a pooling layer.
+        """
+        super().__init__(config)
+
+        requires_backends(self, ["natten"])
+
+        self.config = config
+        self.num_levels = len(config.depths)
+        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))
+
+        self.embeddings = DinatEmbeddings(config)
+        self.encoder = DinatEncoder(config)
+
+        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. `heads_to_prune` is a dict of {layer_num: list of heads to prune in this
+        layer}. See the base class `PreTrainedModel` for details.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, DinatModelOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(pixel_values)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+
+        pooled_output = None
+        if self.pooler is not None:
+            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)
+
+        if not return_dict:
+            output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+            return output
+
+        return DinatModelOutput(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+    of the [CLS] token) e.g. for ImageNet.
+    """
+)
+class DinatForImageClassification(DinatPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        requires_backends(self, ["natten"])
+
+        self.num_labels = config.num_labels
+        self.dinat = DinatModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, DinatImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.dinat(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return DinatImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            reshaped_hidden_states=outputs.reshaped_hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Dinat backbone, to be used with frameworks like DETR and MaskFormer.
+    """
+)
+class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        requires_backends(self, ["natten"])
+
+        self.embeddings = DinatEmbeddings(config)
+        self.encoder = DinatEncoder(config)
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
+
+        # Add layer norms to hidden states of out_features
+        hidden_states_norms = {}
+        for stage, num_channels in zip(self._out_features, self.channels):
+            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
+        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        r"""
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 512, 7, 7]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        embedding_output = self.embeddings(pixel_values)
+
+        outputs = self.encoder(
+            embedding_output,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            output_hidden_states_before_downsampling=True,
+            return_dict=True,
+        )
+
+        hidden_states = outputs.reshaped_hidden_states
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            output = (feature_maps,)
+            if output_hidden_states:
+                output += (outputs.hidden_states,)
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d10027b6a3b6375235a6785df044e8f0ce5fb33
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_dinov2_with_registers import *
+    from .modeling_dinov2_with_registers import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4f446fc684f40d634927c1e7a52b64c5732b12
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -0,0 +1,159 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate a
+    Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the DINOv2 with Registers
+    [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        mlp_ratio (`int`, *optional*, defaults to 4):
+            Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        layerscale_value (`float`, *optional*, defaults to 1.0):
+            Initial value to use for layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            Stochastic depth rate per sample (when applied in the main path of residual layers).
+        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+            Whether to use the SwiGLU feedforward neural network.
+        num_register_tokens (`int`, *optional*, defaults to 4):
+            Number of register tokens to use.
+        out_features (`list[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`list[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        apply_layernorm (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+            seq_len, hidden_size)`.
+
+    Example:
+
+    ```python
+    >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+    >>> # Initializing a Dinov2WithRegisters base style configuration
+    >>> configuration = Dinov2WithRegistersConfig()
+
+    >>> # Initializing a model (with random weights) from the base style configuration
+    >>> model = Dinov2WithRegistersModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "dinov2_with_registers"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        mlp_ratio=4,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-6,
+        image_size=224,
+        patch_size=16,
+        num_channels=3,
+        qkv_bias=True,
+        layerscale_value=1.0,
+        drop_path_rate=0.0,
+        use_swiglu_ffn=False,
+        num_register_tokens=4,
+        out_features=None,
+        out_indices=None,
+        apply_layernorm=True,
+        reshape_hidden_states=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.mlp_ratio = mlp_ratio
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.layerscale_value = layerscale_value
+        self.drop_path_rate = drop_path_rate
+        self.use_swiglu_ffn = use_swiglu_ffn
+        self.num_register_tokens = num_register_tokens
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+        self.apply_layernorm = apply_layernorm
+        self.reshape_hidden_states = reshape_hidden_states
+
+
+__all__ = ["Dinov2WithRegistersConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6f4c335f58b2ddd37bf4042d5f8e51c474cee9
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -0,0 +1,712 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import TransformersKwargs, auto_docstring, torch_int
+from ...utils.backbone_utils import BackboneMixin
+from ...utils.generic import can_return_tuple, check_model_inputs
+from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig
+
+
+class Dinov2WithRegistersPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        num_channels = pixel_values.shape[1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+                f" Expected {self.num_channels} but got {num_channels}."
+            )
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
+
+
+class Dinov2WithRegistersEmbeddings(nn.Module):
+    """
+    Construct the CLS token, mask token, register tokens, position and patch embeddings.
+    """
+
+    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
+        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
+        self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
+        self.config = config
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method interpolates the pre-trained position encodings so that the model can be used on higher
+        resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+        with the original implementation.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+        - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
+        """
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # Skip interpolation for matching dimensions (unless tracing)
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        # Handle class token and patch embeddings separately
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+
+        # Calculate new dimensions
+        height = height // self.config.patch_size
+        width = width // self.config.patch_size
+
+        # Reshape for interpolation
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        # Store original dtype for restoration after interpolation
+        target_dtype = patch_pos_embed.dtype
+
+        # Interpolate at float32 precision
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.to(dtype=torch.float32),
+            size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
+            mode="bicubic",
+            align_corners=False,
+            antialias=True,
+        ).to(dtype=target_dtype)
+
+        # Validate output dimensions if not tracing
+        if not torch.jit.is_tracing():
+            if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+                raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+        # Reshape back to original format
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        # Combine class and patch embeddings
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        target_dtype = self.patch_embeddings.projection.weight.dtype
+        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+        if bool_masked_pos is not None:
+            embeddings = torch.where(
+                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+            )
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+        # add register tokens
+        embeddings = torch.cat(
+            (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+        )
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+    # Normalize the attention scores to probabilities.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    # Mask heads if we want to
+    if attention_mask is not None:
+        attn_weights = attn_weights * attention_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class Dinov2WithRegistersSelfAttention(nn.Module):
+    def __init__(self, config: Dinov2WithRegistersConfig):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.config = config
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.dropout_prob = config.attention_probs_dropout_prob
+        self.scaling = self.attention_head_size**-0.5
+        self.is_causal = False
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+    def forward(
+        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.shape[0]
+        new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+        key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+        query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        context_layer, attention_probs = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            is_causal=self.is_causal,
+            scaling=self.scaling,
+            dropout=0.0 if not self.training else self.dropout_prob,
+        )
+
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.reshape(new_context_layer_shape)
+
+        return context_layer, attention_probs
+
+
+class Dinov2WithRegistersSelfOutput(nn.Module):
+    """
+    The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: Dinov2WithRegistersConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class Dinov2WithRegistersAttention(nn.Module):
+    def __init__(self, config: Dinov2WithRegistersConfig):
+        super().__init__()
+        self.attention = Dinov2WithRegistersSelfAttention(config)
+        self.output = Dinov2WithRegistersSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads: set[int]):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self_attn_output, _ = self.attention(hidden_states, head_mask)
+        output = self.output(self_attn_output, hidden_states)
+        return output
+
+
+class Dinov2WithRegistersLayerScale(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
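+        # learnable per-channel scale applied to the residual branch output, initialized to config.layerscale_value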
+        self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        return hidden_state * self.lambda1
+
+
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
+    output = input.div(keep_prob) * random_tensor
+    return output
+
+
+class Dinov2WithRegistersDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return f"p={self.drop_prob}"
+
+
+class Dinov2WithRegistersMLP(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        in_features = out_features = config.hidden_size
+        hidden_features = int(config.hidden_size * config.mlp_ratio)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
+        if isinstance(config.hidden_act, str):
+            self.activation = ACT2FN[config.hidden_act]
+        else:
+            self.activation = config.hidden_act
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = self.activation(hidden_state)
+        hidden_state = self.fc2(hidden_state)
+        return hidden_state
+
+
+class Dinov2WithRegistersSwiGLUFFN(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        in_features = out_features = config.hidden_size
+        hidden_features = int(config.hidden_size * config.mlp_ratio)
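+        # take 2/3 of the expanded width and round it up to the nearest multiple of 8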
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+        self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
+        self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.weights_in(hidden_state)
+        x1, x2 = hidden_state.chunk(2, dim=-1)
+        hidden = nn.functional.silu(x1) * x2
+        return self.weights_out(hidden)
+
+
+class Dinov2WithRegistersLayer(GradientCheckpointingLayer):
+    """This corresponds to the Block class in the original implementation."""
+
+    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.attention = Dinov2WithRegistersAttention(config)
+        self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
+        self.drop_path = (
+            Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+        )
+
+        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        if config.use_swiglu_ffn:
+            self.mlp = Dinov2WithRegistersSwiGLUFFN(config)
+        else:
+            self.mlp = Dinov2WithRegistersMLP(config)
+        self.layer_scale2 = Dinov2WithRegistersLayerScale(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states_norm = self.norm1(hidden_states)
+        self_attention_output = self.attention(hidden_states_norm, head_mask)
+        self_attention_output = self.layer_scale1(self_attention_output)
+
+        # first residual connection
+        hidden_states = self.drop_path(self_attention_output) + hidden_states
+
+        # in Dinov2WithRegisters, layernorm is also applied after self-attention
+        layer_output = self.norm2(hidden_states)
+        layer_output = self.mlp(layer_output)
+        layer_output = self.layer_scale2(layer_output)
+
+        # second residual connection
+        layer_output = self.drop_path(layer_output) + hidden_states
+
+        return layer_output
+
+
+class Dinov2WithRegistersEncoder(nn.Module):
+    def __init__(self, config: Dinov2WithRegistersConfig):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self, hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False
+    ) -> BaseModelOutput:
+        all_hidden_states = [hidden_states] if output_hidden_states else None
+        for i, layer_module in enumerate(self.layer):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            hidden_states = layer_module(hidden_states, layer_head_mask)
+            if all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
+        )
+
+
+@auto_docstring
+class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
+    config: Dinov2WithRegistersConfig
+    base_model_prefix = "dinov2_with_registers"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Dinov2WithRegistersLayer"]
+    _supports_sdpa = True
+    _supports_flash_attn = True
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "attentions": Dinov2WithRegistersSelfAttention,
+    }
+
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, Dinov2WithRegistersEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+
+            module.cls_token.data = nn.init.trunc_normal_(
+                module.cls_token.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.cls_token.dtype)
+
+            module.mask_token.data.zero_()
+            module.register_tokens.data.zero_()
+        elif isinstance(module, Dinov2WithRegistersLayerScale):  # noqa: F821
+            module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+@auto_docstring
+class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel):
+    def __init__(self, config: Dinov2WithRegistersConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = Dinov2WithRegistersEmbeddings(config)
+        self.encoder = Dinov2WithRegistersEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+        base class `PreTrainedModel`.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @check_model_inputs(tie_last_hidden_states=False)
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs,
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+            pre-training.
+        """
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+        encoder_outputs: BaseModelOutput = self.encoder(
+            embedding_output, head_mask=head_mask, output_hidden_states=output_hidden_states
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = sequence_output[:, 0, :]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+    of the [CLS] token) e.g. for ImageNet.
+    """
+)
+class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel):
+    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.dinov2_with_registers = Dinov2WithRegistersModel(config)
+
+        # Classifier head
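+        # the classifier input is the [CLS] token concatenated with the mean-pooled patch tokens, hence hidden_size * 2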
+        self.classifier = (
+            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ImageClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        outputs: BaseModelOutputWithPooling = self.dinov2_with_registers(pixel_values, head_mask=head_mask, **kwargs)
+        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
+
+        cls_token = sequence_output[:, 0]
+        # cls and register tokens should not be included in patch tokens variable
+        patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
+
+        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+        logits = self.classifier(linear_input)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
+    """
+)
+class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+        self.embeddings = Dinov2WithRegistersEmbeddings(config)
+        self.encoder = Dinov2WithRegistersEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.num_register_tokens = config.num_register_tokens
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs,
+    ) -> BackboneOutput:
+        r"""
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 768, 16, 16]
+        ```"""
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        embedding_output = self.embeddings(pixel_values)
+        output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+        hidden_states = output.hidden_states
+
+        feature_maps = []
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                if self.config.apply_layernorm:
+                    hidden_state = self.layernorm(hidden_state)
+                if self.config.reshape_hidden_states:
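+                    # drop the [CLS] and register tokens before reshaping into a (batch_size, hidden_size, height, width) map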
+                    hidden_state = hidden_state[:, 1 + self.num_register_tokens :]
+                    # this was actually a bug in the original implementation that we copied here,
+                    # because normally the order is height, width
+                    batch_size, _, height, width = pixel_values.shape
+                    patch_size = self.config.patch_size
+                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps.append(hidden_state)
+
+        return BackboneOutput(
+            feature_maps=tuple(feature_maps),
+            hidden_states=hidden_states if output_hidden_states else None,
+        )
+
+
+__all__ = [
+    "Dinov2WithRegistersPreTrainedModel",
+    "Dinov2WithRegistersModel",
+    "Dinov2WithRegistersForImageClassification",
+    "Dinov2WithRegistersBackbone",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..686528002b09c9689d66a057ed55eb1a43b0d256
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
@@ -0,0 +1,435 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ....transformers.models.dinov2.modeling_dinov2 import (
+    Dinov2Backbone,
+    Dinov2Encoder,
+    Dinov2ForImageClassification,
+    Dinov2Model,
+    Dinov2PatchEmbeddings,
+    Dinov2PreTrainedModel,
+)
+from ...configuration_utils import PretrainedConfig
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, logging, torch_int
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
+    Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the DINOv2 with Registers
+    [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        mlp_ratio (`int`, *optional*, defaults to 4):
+            Ratio of the hidden size of the MLPs relative to the `hidden_size`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        layerscale_value (`float`, *optional*, defaults to 1.0):
+           Initial value to use for layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            Stochastic depth rate per sample (when applied in the main path of residual layers).
+        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
+            Whether to use the SwiGLU feedforward neural network.
+        num_register_tokens (`int`, *optional*, defaults to 4):
+            Number of register tokens to use.
+        out_features (`list[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`list[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        apply_layernorm (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+            seq_len, hidden_size)`.
+
+    Example:
+
+    ```python
+    >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
+
+    >>> # Initializing a Dinov2WithRegisters base style configuration
+    >>> configuration = Dinov2WithRegistersConfig()
+
+    >>> # Initializing a model (with random weights) from the base style configuration
+    >>> model = Dinov2WithRegistersModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "dinov2_with_registers"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        mlp_ratio=4,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-6,
+        image_size=224,
+        patch_size=16,
+        num_channels=3,
+        qkv_bias=True,
+        layerscale_value=1.0,
+        drop_path_rate=0.0,
+        use_swiglu_ffn=False,
+        num_register_tokens=4,
+        out_features=None,
+        out_indices=None,
+        apply_layernorm=True,
+        reshape_hidden_states=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.mlp_ratio = mlp_ratio
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.layerscale_value = layerscale_value
+        self.drop_path_rate = drop_path_rate
+        self.use_swiglu_ffn = use_swiglu_ffn
+        self.num_register_tokens = num_register_tokens
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+        self.apply_layernorm = apply_layernorm
+        self.reshape_hidden_states = reshape_hidden_states
+
+
+class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings):
+    pass
+
+
+class Dinov2WithRegistersEmbeddings(nn.Module):
+    """
+    Construct the CLS token, mask token, register tokens, position and patch embeddings.
+    """
+
+    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
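+        # learnable register tokens, inserted between the [CLS] token and the patch tokens in forward()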
+        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
+        self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
+        self.config = config
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method interpolates the pre-trained position encodings so that the model can be used on higher
+        resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
+        with the original implementation.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+        - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
+        """
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # Skip interpolation for matching dimensions (unless tracing)
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        # Handle class token and patch embeddings separately
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+
+        # Calculate new dimensions
+        height = height // self.config.patch_size
+        width = width // self.config.patch_size
+
+        # Reshape for interpolation
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        # Store original dtype for restoration after interpolation
+        target_dtype = patch_pos_embed.dtype
+
+        # Interpolate at float32 precision
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.to(dtype=torch.float32),
+            size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
+            mode="bicubic",
+            align_corners=False,
+            antialias=True,
+        ).to(dtype=target_dtype)
+
+        # Validate output dimensions if not tracing
+        if not torch.jit.is_tracing():
+            if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
+                raise ValueError("Width or height does not match with the interpolated position embeddings")
+
+        # Reshape back to original format
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        # Combine class and patch embeddings
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
+        batch_size, _, height, width = pixel_values.shape
+        target_dtype = self.patch_embeddings.projection.weight.dtype
+        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
+
+        if bool_masked_pos is not None:
+            embeddings = torch.where(
+                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
+            )
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+
+        # add register tokens
+        embeddings = torch.cat(
+            (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
+        )
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class Dinov2WithRegistersEncoder(Dinov2Encoder):
+    pass
+
+
+class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel):
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, Dinov2WithRegistersEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+
+            module.cls_token.data = nn.init.trunc_normal_(
+                module.cls_token.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.cls_token.dtype)
+
+            module.mask_token.data.zero_()
+            module.register_tokens.data.zero_()
+        elif isinstance(module, Dinov2WithRegistersLayerScale):  # noqa: F821
+            module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+class Dinov2WithRegistersModel(Dinov2Model):
+    pass
+
+
+class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification):
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> ImageClassifierOutput:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        outputs: BaseModelOutputWithPooling = self.dinov2_with_registers(pixel_values, head_mask=head_mask, **kwargs)
+        sequence_output = outputs.last_hidden_state  # batch_size, sequence_length, hidden_size
+
+        cls_token = sequence_output[:, 0]
+        # cls and register tokens should not be included in patch tokens variable
+        patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
+
+        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+        logits = self.classifier(linear_input)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(labels, logits, self.config, **kwargs)
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class Dinov2WithRegistersBackbone(Dinov2Backbone):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        self.num_register_tokens = config.num_register_tokens
+        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+        self.embeddings = Dinov2WithRegistersEmbeddings(config)
+        self.encoder = Dinov2WithRegistersEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        **kwargs,
+    ) -> BackboneOutput:
+        r"""
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 768, 16, 16]
+        ```"""
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
+        embedding_output = self.embeddings(pixel_values)
+        output: BaseModelOutput = self.encoder(embedding_output, output_hidden_states=True)
+        hidden_states = output.hidden_states
+
+        feature_maps = []
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                if self.config.apply_layernorm:
+                    hidden_state = self.layernorm(hidden_state)
+                if self.config.reshape_hidden_states:
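+                    # drop the [CLS] and register tokens before reshaping into a (batch_size, hidden_size, height, width) map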
+                    hidden_state = hidden_state[:, 1 + self.num_register_tokens :]
+                    # this was actually a bug in the original implementation that we copied here,
+                    # because normally the order is height, width
+                    batch_size, _, height, width = pixel_values.shape
+                    patch_size = self.config.patch_size
+                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps.append(hidden_state)
+
+        return BackboneOutput(
+            feature_maps=tuple(feature_maps),
+            hidden_states=hidden_states if output_hidden_states else None,
+        )
+
+
+__all__ = [
+    "Dinov2WithRegistersConfig",
+    "Dinov2WithRegistersPreTrainedModel",
+    "Dinov2WithRegistersModel",
+    "Dinov2WithRegistersForImageClassification",
+    "Dinov2WithRegistersBackbone",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/dit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c786feb9213fdd31640c0fdeaead5164026ad37a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_encoder_decoder import *
+    from .modeling_encoder_decoder import *
+    from .modeling_flax_encoder_decoder import *
+    from .modeling_tf_encoder_decoder import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..af57b2596cee99eefe0493cc4aea51c845036d2e
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class EncoderDecoderConfig(PretrainedConfig):
+    r"""
+    [`EncoderDecoderConfig`] is the configuration class to store the configuration of a [`EncoderDecoderModel`]. It is
+    used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder
+    configs.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        kwargs (*optional*):
+            Dictionary of keyword arguments. Notably:
+
+                - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the encoder config.
+                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+                  the decoder config.
+
+    Examples:
+
+    ```python
+    >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
+
+    >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
+    >>> config_encoder = BertConfig()
+    >>> config_decoder = BertConfig()
+
+    >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+    >>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations
+    >>> model = EncoderDecoderModel(config=config)
+
+    >>> # Accessing the model configuration
+    >>> config_encoder = model.config.encoder
+    >>> config_decoder = model.config.decoder
+    >>> # set decoder config to causal lm
+    >>> config_decoder.is_decoder = True
+    >>> config_decoder.add_cross_attention = True
+
+    >>> # Saving the model, including its configuration
+    >>> model.save_pretrained("my-model")
+
+    >>> # loading model and config from pretrained folder
+    >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model")
+    >>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+    ```"""
+
+    model_type = "encoder-decoder"
+    sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
+    has_no_defaults_at_init = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if "encoder" not in kwargs or "decoder" not in kwargs:
+            raise ValueError(
+                f"A configuration of type {self.model_type} cannot be instantiated because "
+                f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}"
+            )
+        encoder_config = kwargs.pop("encoder")
+        encoder_model_type = encoder_config.pop("model_type")
+        decoder_config = kwargs.pop("decoder")
+        decoder_model_type = decoder_config.pop("model_type")
+
+        self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
+        self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
+        self.is_encoder_decoder = True
+
+    @classmethod
+    def from_encoder_decoder_configs(
+        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+    ) -> PretrainedConfig:
+        r"""
+        Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
+        decoder model configuration.
+
+        Returns:
+            [`EncoderDecoderConfig`]: An instance of a configuration object
+        """
+        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+        decoder_config.is_decoder = True
+        decoder_config.add_cross_attention = True
+
+        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+
+__all__ = ["EncoderDecoderConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..30e2370b2240d2f2f00182abea548cd6a72b5626
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -0,0 +1,609 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support Encoder-Decoder architectures"""
+
+import gc
+import inspect
+import os
+import tempfile
+import warnings
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...cache_utils import Cache
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, logging
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+DEPRECATION_WARNING = (
+    "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+    " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+    " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the"
+    " labels, no need to pass them yourself anymore."
+)
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
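+
+    For example, labels `[[5, 6, -100]]` with `decoder_start_token_id=2` become `[[2, 5, 6]]`; any remaining `-100`
+    values are then replaced by `pad_token_id`.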
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
+@auto_docstring
+class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
+    r"""
+    [`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+    of the base model classes of the library as encoder and another one as decoder when created with the
+    [`~transformers.AutoModel.from_pretrained`] class method for the encoder and
+    [`~transformers.AutoModelForCausalLM.from_pretrained`] class method for the decoder.
+    """
+
+    config: EncoderDecoderConfig
+    base_model_prefix = "encoder_decoder"
+    main_input_name = "input_ids"
+    supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+    ):
+        r"""
+        encoder (`PreTrainedModel`, *optional*):
+            The encoder model to use.
+        decoder (`PreTrainedModel`, *optional*):
+            The decoder model to use.
+        """
+        if config is None and (encoder is None or decoder is None):
+            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+        if config is None:
+            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+        else:
+            if not isinstance(config, self.config_class):
+                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+        if config.decoder.cross_attention_hidden_size is not None:
+            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+                raise ValueError(
+                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+                    " `config.encoder.hidden_size`."
+                )
+
+        # initialize with config
+        super().__init__(config)
+
+        if encoder is None:
+            from ..auto.modeling_auto import AutoModel
+
+            encoder = AutoModel.from_config(config.encoder)
+
+        if decoder is None:
+            from ..auto.modeling_auto import AutoModelForCausalLM
+
+            decoder = AutoModelForCausalLM.from_config(config.decoder)
+
+        self.encoder = encoder
+        self.decoder = decoder
+
+        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+            logger.warning(
+                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+                f" {self.config.encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                f" {self.config.decoder}"
+            )
+
+        # make sure that the individual model's config refers to the shared config
+        # so that the updates to the config will be synced
+        # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel
+        self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
+        self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+        self.encoder.config = self.config.encoder
+        self.decoder.config = self.config.decoder
+
+        # encoder outputs might need to be projected to a different dimension for the decoder
+        if (
+            self.encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+        if self.encoder.get_output_embeddings() is not None:
+            raise ValueError(
+                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
+            )
+
+        decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+        if "encoder_hidden_states" not in decoder_signature:
+            raise ValueError(
+                "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+                "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+            )
+
+        # tie encoder, decoder weights if config set accordingly
+        self.tie_weights()
+
+    def tie_weights(self):
+        self.encoder.tie_weights()
+        self.decoder.tie_weights()
+        # tie encoder & decoder if needed
+        if self.config.tie_encoder_decoder:
+            # tie encoder and decoder base model
+            decoder_base_model_prefix = self.decoder.base_model_prefix
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "encoder",
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls, e.g. by different tests.
+            self._dynamic_tied_weights_keys = tied_weights
+
+    def _init_weights(self, module):
+        if module in self.encoder.modules():
+            self.encoder._init_weights(module)
+        elif module in self.decoder.modules():
+            self.decoder._init_weights(module)
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_input_embeddings(self):
+        return self.encoder.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.decoder.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        return self.decoder.set_output_embeddings(new_embeddings)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import EncoderDecoderModel
+
+        >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+        ```"""
+
+        from_tf = kwargs.pop("from_tf", False)
+        if from_tf:
+            from transformers import TFEncoderDecoderModel
+
+            # a workaround to load from tensorflow checkpoint
+            # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get
+            # extended before saving those components. For example, the name of `_tf_model.encoder.vit` is
+            # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The
+            # [top model name] is handled (stripped) by the conversion method, and the former case gets an extra `encoder`,
+            # which should not occur when we want to save the components alone.
+            # There was a (very) ugly potential fix, which wasn't integrated into `transformers`: see
+            #   https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
+            #   (the change in `src/transformers/modeling_tf_utils.py`)
+            _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            config = _tf_model.config
+
+            # Using `tf_model` instead
+            encoder = _tf_model.encoder.__class__(_tf_model.config.encoder)
+            decoder = _tf_model.decoder.__class__(_tf_model.config.decoder)
+            # Make sure models are built
+            encoder(encoder.dummy_inputs)
+            decoder(decoder.dummy_inputs)
+
+            # Get the variable correspondence between `_tf_model` and `encoder` and `decoder`
+            encoder_variables = {}
+            for v in encoder.trainable_variables + encoder.non_trainable_variables:
+                encoder_variables["/".join(v.name.split("/")[1:])] = v
+            decoder_variables = {}
+            for v in decoder.trainable_variables + decoder.non_trainable_variables:
+                decoder_variables["/".join(v.name.split("/")[1:])] = v
+
+            _encoder_variables = {}
+            for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables:
+                _encoder_variables["/".join(v.name.split("/")[2:])] = v
+            _decoder_variables = {}
+            for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables:
+                _decoder_variables["/".join(v.name.split("/")[2:])] = v
+
+            # assign weight values to `encoder` and `decoder` from `_tf_model`
+            for name, v in encoder_variables.items():
+                v.assign(_encoder_variables[name])
+            for name, v in decoder_variables.items():
+                v.assign(_decoder_variables[name])
+
+            tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+            # Deal with `enc_to_dec_proj`
+            if hasattr(_tf_model, "enc_to_dec_proj"):
+                tf_model(tf_model.dummy_inputs)
+                tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel)
+                tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                encoder_dir = os.path.join(tmpdirname, "encoder")
+                decoder_dir = os.path.join(tmpdirname, "decoder")
+                tf_model.encoder.save_pretrained(encoder_dir)
+                tf_model.decoder.save_pretrained(decoder_dir)
+
+                if hasattr(tf_model, "enc_to_dec_proj"):
+                    enc_to_dec_proj_weight = torch.transpose(
+                        torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0
+                    )
+                    enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy())
+
+                del _tf_model
+                del tf_model
+                gc.collect()
+
+                model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
+                )
+                # This is only for copying some specific attributes of this particular model.
+                model.config = config
+
+                if hasattr(model, "enc_to_dec_proj"):
+                    model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
+                    model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()
+
+                return model
+
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+    @classmethod
+    def from_encoder_decoder_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path: Optional[str] = None,
+        decoder_pretrained_model_name_or_path: Optional[str] = None,
+        *model_args,
+        **kwargs,
+    ) -> PreTrainedModel:
+        r"""
+        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+        checkpoints.
+
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you need to first set it back in training mode with `model.train()`.
+
+        Params:
+            encoder_pretrained_model_name_or_path (`str`, *optional*):
+                Information necessary to initialize the encoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or URL to a *TensorFlow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
+                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
+                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+            decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+                Information necessary to initialize the decoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or URL to a *TensorFlow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
+                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
+                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+                `output_attentions=True`).
+
+                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import EncoderDecoderModel
+
+        >>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
+        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
+        >>> # saving model after fine-tuning
+        >>> model.save_pretrained("./bert2bert")
+        >>> # load fine-tuned model
+        >>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
+        ```"""
+
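+        # Illustrative example (hypothetical values): a call such as
+        #   from_encoder_decoder_pretrained(enc_path, dec_path,
+        #                                   encoder_hidden_dropout_prob=0.2,
+        #                                   decoder_num_hidden_layers=6)
+        # is split below into kwargs_encoder = {"hidden_dropout_prob": 0.2} and
+        # kwargs_decoder = {"num_hidden_layers": 6}; unprefixed kwargs stay in `kwargs`.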
+        kwargs_encoder = {
+            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # remove encoder, decoder kwargs from kwargs
+        for key in kwargs_encoder:
+            del kwargs["encoder_" + key]
+        for key in kwargs_decoder:
+            del kwargs["decoder_" + key]
+
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs_encoder.pop("model", None)
+        if encoder is None:
+            if encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_encoder:
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
+
+                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                    logger.info(
+                        f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
+                        "from a decoder model. Cross-attention and causal mask are disabled."
+                    )
+                    encoder_config.is_decoder = False
+                    encoder_config.add_cross_attention = False
+
+                kwargs_encoder["config"] = encoder_config
+
+            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+
+        decoder = kwargs_decoder.pop("model", None)
+        if decoder is None:
+            if decoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_decoder:
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
+
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+                    logger.info(
+                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+                    )
+                    decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True
+
+                kwargs_decoder["config"] = decoder_config
+
+            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+                logger.warning(
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                )
+
+            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+        # instantiate config with corresponding kwargs
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+        return cls(encoder=encoder, decoder=decoder, config=config)
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[tuple, Seq2SeqLMOutput]:
+        r"""
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
+            right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
+            into associated vectors than the model's internal embedding lookup matrix.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
+            ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Examples:
+
+        ```python
+        >>> from transformers import EncoderDecoderModel, BertTokenizer
+        >>> import torch
+
+        >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+        ...     "google-bert/bert-base-uncased", "google-bert/bert-base-uncased"
+        ... )  # initialize Bert2Bert from pre-trained checkpoints
+
+        >>> # training
+        >>> model.config.decoder_start_token_id = tokenizer.cls_token_id
+        >>> model.config.pad_token_id = tokenizer.pad_token_id
+        >>> model.config.vocab_size = model.config.decoder.vocab_size
+
+        >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
+        >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
+        >>> outputs = model(input_ids=input_ids, labels=labels)
+        >>> loss, logits = outputs.loss, outputs.logits
+
+        >>> # save and load from pretrained
+        >>> model.save_pretrained("bert2bert")
+        >>> model = EncoderDecoderModel.from_pretrained("bert2bert")
+
+        >>> # generation
+        >>> generated = model.generate(input_ids)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
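+        # `num_items_in_batch` (when provided, e.g. by a training loop) is moved from the encoder
+        # kwargs to the decoder kwargs so that only the decoder call receives it.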
+        if "num_items_in_batch" in kwargs_encoder:
+            kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                **kwargs_encoder,
+            )
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        # optionally project encoder_hidden_states
+        if (
+            self.encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
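+        # Teacher forcing: when only `labels` are provided, derive `decoder_input_ids` by shifting the
+        # labels one position to the right and build a default decoder attention mask that ignores pad tokens.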
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+            if decoder_attention_mask is None:
+                decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            warnings.warn(DEPRECATION_WARNING, FutureWarning)
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+    def resize_token_embeddings(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+            " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+            " model.decoder.resize_token_embeddings(...))"
+        )
+
+
+__all__ = ["EncoderDecoderModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a27c23c3c69ae928c73273c9397d5f5aad2b1c0
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
@@ -0,0 +1,901 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support Flax Encoder-Decoder architectures"""
+
+import os
+from typing import Optional, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
+from ...modeling_flax_utils import FlaxPreTrainedModel
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "EncoderDecoderConfig"
+
+ENCODER_DECODER_START_DOCSTRING = r"""
+    This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
+    encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
+    [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
+    function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
+    generative task, like summarization.
+
+    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+    Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
+
+    After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other model
+    (see the examples for more information).
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a Flax Linen
+    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+    Parameters:
+        config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified, all the computation will be performed with the given `dtype`.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
+        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.encoder.max_position_embeddings - 1]`.
+        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.decoder.max_position_embeddings - 1]`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
+"""
+
+ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.encoder.max_position_embeddings - 1]`.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
+"""
+
+ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
+    Args:
+        decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+            created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+            and prepending them with the `decoder_start_token_id`.
+        encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+            Tuple consisting of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+            range `[0, config.decoder.max_position_embeddings - 1]`.
+        past_key_values (`dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+            Dictionary of pre-computed hidden-states (keys and values in the attention blocks) that can be used for fast
+            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
+            plain tuple.
+"""
+
+
+class FlaxEncoderDecoderModule(nn.Module):
+    config: EncoderDecoderConfig
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        encoder_config = self.config.encoder
+        decoder_config = self.config.decoder
+
+        # Copied from `modeling_hybrid_clip.py` with modifications.
+        from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
+
+        encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
+        decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
+
+        self.encoder = encoder_module(encoder_config, dtype=self.dtype)
+        self.decoder = decoder_module(decoder_config, dtype=self.dtype)
+
+        # encoder outputs might need to be projected to a different dimension for the decoder
+        if (
+            self.encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            self.enc_to_dec_proj = nn.Dense(
+                self.decoder.config.hidden_size,
+                kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
+                dtype=self.dtype,
+            )
+        else:
+            self.enc_to_dec_proj = None
+
+    def _get_encoder_module(self):
+        return self.encoder
+
+    def _get_projection_module(self):
+        return self.enc_to_dec_proj
+
+    def _get_decoder_module(self):
+        return self.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        encoder_outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        # optionally project encoder_hidden_states
+        if self.enc_to_dec_proj is not None:
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            position_ids=decoder_position_ids,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return FlaxSeq2SeqLMOutput(
+            logits=decoder_outputs.logits,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
+    r"""
+    [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
+    the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as
+    decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
+    encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
+    """
+
+    config_class = EncoderDecoderConfig
+    base_model_prefix = "encoder_decoder"
+    module_class = FlaxEncoderDecoderModule
+
+    def __init__(
+        self,
+        config: EncoderDecoderConfig,
+        input_shape: Optional[tuple] = None,
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+        **kwargs,
+    ):
+        if input_shape is None:
+            input_shape = ((1, 1), (1, 1))
+
+        if not _do_init:
+            raise ValueError(
+                "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
+            )
+
+        if config.decoder.cross_attention_hidden_size is not None:
+            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+                raise ValueError(
+                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+                    " `config.encoder.hidden_size`."
+                )
+
+        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
+        encoder_input_shape, decoder_input_shape = input_shape
+
+        # init input tensors
+        input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
+        attention_mask = jnp.ones_like(input_ids)
+        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
+        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+        batch_size, sequence_length = input_ids.shape
+        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
+        if decoder_batch_size != batch_size:
+            raise ValueError(
+                f"The encoder and decoder inputs should have the same batch size, but got {batch_size} for the encoder"
+                f" and {decoder_batch_size} for the decoder."
+            )
+        decoder_position_ids = jnp.broadcast_to(
+            jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
+        )
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        random_params = self.module.init(
+            rngs,
+            input_ids,
+            attention_mask,
+            decoder_input_ids,
+            decoder_attention_mask,
+            position_ids,
+            decoder_position_ids,
+        )["params"]
+
+        if params is not None:
+            random_params = flatten_dict(unfreeze(random_params))
+            params = flatten_dict(unfreeze(params))
+            for missing_key in self._missing_keys:
+                params[missing_key] = random_params[missing_key]
+            self._missing_keys = set()
+            return freeze(unflatten_dict(params))
+        else:
+            return random_params
+
+    def init_cache(self, batch_size, max_length, encoder_outputs):
+        r"""
+        Args:
+            batch_size (`int`):
+                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+            max_length (`int`):
+                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+                cache.
+            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`
+                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+                cross-attention of the decoder.
+        """
+        # init input variables to retrieve cache
+        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        decoder_position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+        )
+
+        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+            decoder_module = module._get_decoder_module()
+            return decoder_module(
+                input_ids=decoder_input_ids,
+                attention_mask=decoder_attention_mask,
+                position_ids=decoder_position_ids,
+                **kwargs,
+            )
+
+        init_variables = self.module.init(
+            jax.random.PRNGKey(0),
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_position_ids=decoder_position_ids,
+            encoder_hidden_states=encoder_outputs[0],
+            init_cache=True,
+            method=_decoder_forward,  # we only need to call the decoder to init the cache
+        )
+        return unfreeze(init_variables["cache"])
+
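+    # Rough usage sketch (assumed names, not from the upstream docstring): for step-wise decoding,
+    # the cache returned by `init_cache` is fed back to `decode` as `past_key_values`, together with
+    # explicit `decoder_position_ids`, e.g.
+    #   cache = model.init_cache(batch_size, max_length, encoder_outputs)
+    #   step = model.decode(next_token_ids, encoder_outputs, past_key_values=cache,
+    #                       decoder_position_ids=position_ids)
+    #   cache = step.past_key_values  # updated cache for the following step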
+    @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def encode(
+        self,
+        input_ids: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        position_ids: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: Optional[dict] = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
+
+        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+        >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+        >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+        >>> text = "My friends are cool but they eat too many carbs."
+        >>> input_ids = tokenizer.encode(text, return_tensors="np")
+        >>> encoder_outputs = model.encode(input_ids)
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+        if position_ids is None:
+            batch_size, sequence_length = input_ids.shape
+            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
+            encode_module = module._get_encoder_module()
+            return encode_module(input_ids, attention_mask, position_ids, **kwargs)
+
+        outputs = self.module.apply(
+            {"params": params or self.params},
+            input_ids=jnp.array(input_ids, dtype="i4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            position_ids=jnp.array(position_ids, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            method=_encoder_forward,
+        )
+
+        if return_dict:
+            outputs = FlaxBaseModelOutput(
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+
+        return outputs
+
+    @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def decode(
+        self,
+        decoder_input_ids,
+        encoder_outputs,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        past_key_values: Optional[dict] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: Optional[dict] = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
+        >>> import jax.numpy as jnp
+
+        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+        >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+        >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+        >>> text = "My friends are cool but they eat too many carbs."
+        >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np")
+        >>> encoder_outputs = model.encode(input_ids)
+
+        >>> decoder_start_token_id = model.config.decoder.bos_token_id
+        >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+
+        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
+        >>> logits = outputs.logits
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        encoder_hidden_states = encoder_outputs[0]
+        if encoder_attention_mask is None:
+            batch_size, sequence_length = encoder_hidden_states.shape[:2]
+            encoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+        batch_size, sequence_length = decoder_input_ids.shape
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones((batch_size, sequence_length))
+
+        if decoder_position_ids is None:
+            if past_key_values is not None:
+                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
+
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {}
+        if dropout_rng is not None:
+            rngs["dropout"] = dropout_rng
+
+        inputs = {"params": params or self.params}
+
+        # If past_key_values are passed, the cache is already initialized and a private flag `init_cache` has to
+        # be passed down to ensure the cache is used. It also has to be made sure that the cache is marked as
+        # mutable so that it can be changed by the FlaxBartAttention module.
+        if past_key_values:
+            inputs["cache"] = past_key_values
+            mutable = ["cache"]
+        else:
+            mutable = False
+
+        def _decoder_forward(
+            module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
+        ):
+            projection_module = module._get_projection_module()
+            decoder_module = module._get_decoder_module()
+
+            # optionally project encoder_hidden_states
+            if projection_module is not None:
+                encoder_hidden_states = projection_module(encoder_hidden_states)
+
+            return decoder_module(
+                decoder_input_ids,
+                decoder_attention_mask,
+                decoder_position_ids,
+                encoder_hidden_states=encoder_hidden_states,
+                **kwargs,
+            )
+
+        outputs = self.module.apply(
+            inputs,
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+            mutable=mutable,
+            method=_decoder_forward,
+        )
+
+        # add updated cache to model output
+        if past_key_values is not None and return_dict:
+            outputs, past = outputs
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs, past = outputs
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def __call__(
+        self,
+        input_ids: jnp.ndarray,
+        attention_mask: Optional[jnp.ndarray] = None,
+        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        position_ids: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        train: bool = False,
+        params: Optional[dict] = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer
+
+        >>> # load a fine-tuned bert2gpt2 model
+        >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
+        >>> # load input & output tokenizer
+        >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+        >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+
+        >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
+        >>> singing a racist chant. SAE's national chapter suspended the students,
+        >>> but University of Oklahoma President David Boren took it a step further,
+        >>> saying the university's affiliation with the fraternity is permanently done.'''
+
+        >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids
+
+        >>> # use GPT2's eos_token as the pad as well as eos token
+        >>> model.config.eos_token_id = model.config.decoder.eos_token_id
+        >>> model.config.pad_token_id = model.config.eos_token_id
+
+        >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
+
+        >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
+        >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
+        ```
+        """
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        # prepare encoder inputs
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+        if position_ids is None:
+            batch_size, sequence_length = input_ids.shape
+            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+        # prepare decoder inputs
+        if decoder_input_ids is None:
+            raise ValueError(
+                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must"
+                " be explicitly provided as an input argument."
+            )
+        if decoder_attention_mask is None:
+            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+        if decoder_position_ids is None:
+            batch_size, sequence_length = decoder_input_ids.shape
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+            )
+
+        # Handle any PRNG if needed
+        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+        return self.module.apply(
+            {"params": params or self.params},
+            input_ids=jnp.array(input_ids, dtype="i4"),
+            attention_mask=jnp.array(attention_mask, dtype="i4"),
+            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+            position_ids=jnp.array(position_ids, dtype="i4"),
+            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=not train,
+            rngs=rngs,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
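+            # derive position ids from the attention mask so that padded positions do not advance the position counter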
+            decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            decoder_position_ids = jnp.broadcast_to(
+                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
+            )
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": decoder_position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
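+        # during incremental decoding a single new token is fed per step, so the position id simply advances by one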
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+    @classmethod
+    def from_encoder_decoder_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        *model_args,
+        **kwargs,
+    ) -> FlaxPreTrainedModel:
+        r"""
+        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+        checkpoints.
+
+        Params:
+            encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
+                Information necessary to initiate the encoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
+                Information necessary to initiate the decoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
+                `output_attentions=True`).
+
+                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import FlaxEncoderDecoderModel
+
+        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+        >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+        >>> # saving model after fine-tuning
+        >>> model.save_pretrained("./bert2gpt2")
+        >>> # load fine-tuned model
+        >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2")
+        ```"""
+
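+        # Split `kwargs` by prefix: e.g. a (hypothetical) `encoder_hidden_dropout_prob=0.2` is forwarded to the
+        # encoder as `hidden_dropout_prob=0.2` and `decoder_use_cache=False` to the decoder, while unprefixed
+        # kwargs are used for the combined `EncoderDecoderConfig` below.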
+        kwargs_encoder = {
+            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # remove encoder, decoder kwargs from kwargs
+        for key in kwargs_encoder:
+            del kwargs["encoder_" + key]
+        for key in kwargs_decoder:
+            del kwargs["decoder_" + key]
+
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs_encoder.pop("model", None)
+        if encoder is None:
+            if encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_encoder:
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
+                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                    logger.info(
+                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
+                        "from a decoder model. Cross-attention and causal mask are disabled."
+                    )
+                    encoder_config.is_decoder = False
+                    encoder_config.add_cross_attention = False
+
+                kwargs_encoder["config"] = encoder_config
+
+            encoder = FlaxAutoModel.from_pretrained(
+                encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
+            )
+
+        decoder = kwargs_decoder.pop("model", None)
+        if decoder is None:
+            if decoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_decoder:
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+                    logger.info(
+                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+                    )
+                    decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True
+
+                kwargs_decoder["config"] = decoder_config
+
+            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+                logger.warning(
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                )
+
+            decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+        # instantiate config with corresponding kwargs
+        dtype = kwargs.pop("dtype", jnp.float32)
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+
+        # init model
+        model = cls(config, dtype=dtype)
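+        # replace the randomly initialized encoder/decoder parameter sub-trees with the pretrained weights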
+        model.params["encoder"] = encoder.params
+        model.params["decoder"] = decoder.params
+
+        return model
+
+
+__all__ = ["FlaxEncoderDecoderModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e5343d200499e1f3b8ba26f8d70924c2999a2fc
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -0,0 +1,661 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support TF Encoder-Decoder architectures"""
+
+from __future__ import annotations
+
+import inspect
+import re
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    get_initializer,
+    keras,
+    unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "EncoderDecoderConfig"
+
+DEPRECATION_WARNING = (
+    "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+    " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+    " fine-tuning a model trained with versions prior to 4.17.0. The decoder_input_ids are now created based on the"
+    " labels; there is no need to pass them yourself anymore."
+)
+
+ENCODER_DECODER_START_DOCSTRING = r"""
+    This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
+    encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
+    [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`]
+    function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
+    generative task, like summarization.
+
+    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+    Tasks](https://huggingface.co/papers/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
+    Zhou, Wei Li, Peter J. Liu.
+
+    After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
+    (see the examples for more information).
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`np.ndarray`, `tf.Tensor`, `list[tf.Tensor]`, `dict[str, tf.Tensor]` or `dict[str, np.ndarray]`, and each example must have the shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            Provide for sequence to sequence training to the decoder. Indices can be obtained using
+            [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
+            details.
+        decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        encoder_outputs (`tuple(tuple(tf.Tensor))`, *optional*):
+            This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output
+            of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `({0})`.
+        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
+            into associated vectors than the model's internal embedding lookup matrix.
+        labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
+            Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
+            ..., config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+        kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
+
+            - Without a prefix, which will be input as `**encoder_kwargs` for the encoder forward function.
+            - With a *decoder_* prefix, which will be input as `**decoder_kwargs` for the decoder forward function.
+"""
+
+
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
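+    # Shift `input_ids` (typically the labels) one position to the right and prepend `decoder_start_token_id`,
+    # e.g. [a, b, c, eos] -> [start, a, b, c]; any -100 entries (ignored label positions) are replaced with `pad_token_id`.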
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+
+    start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
+    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = tf.where(
+        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
+    )
+
+    # "Verify that `labels` has only positive values and -100"
+    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+    # Make sure the assertion op is called by wrapping the result in an identity no-op
+    with tf.control_dependencies([assert_gte0]):
+        shifted_input_ids = tf.identity(shifted_input_ids)
+
+    return shifted_input_ids
+
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
+    r"""
+    [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+    of the base model classes of the library as encoder and another one as decoder when created with the
+    [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
+    method for the decoder.
+    """
+
+    config_class = EncoderDecoderConfig
+    base_model_prefix = "encoder_decoder"
+    load_weight_prefix = "tf_encoder_decoder_model"
+
+    def __init__(
+        self,
+        config: PretrainedConfig | None = None,
+        encoder: TFPreTrainedModel | None = None,
+        decoder: TFPreTrainedModel | None = None,
+    ):
+        if config is None and (encoder is None or decoder is None):
+            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+        if config is None:
+            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+        else:
+            if not isinstance(config, self.config_class):
+                raise ValueError(f"config: {config} has to be of type {self.config_class}")
+
+        if config.decoder.cross_attention_hidden_size is not None:
+            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+                raise ValueError(
+                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+                    " `config.encoder.hidden_size`."
+                )
+
+        # initialize with config
+        super().__init__(config)
+
+        if encoder is None:
+            encoder = TFAutoModel.from_config(config.encoder, name="encoder")
+
+        if decoder is None:
+            decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")
+
+        self.encoder = encoder
+        self.decoder = decoder
+
+        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+            logger.warning(
+                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+                f" {self.config.encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                f" {self.config.decoder}"
+            )
+
+        # make sure that the individual model's config refers to the shared config
+        # so that the updates to the config will be synced
+        self.encoder.config = self.config.encoder
+        self.decoder.config = self.config.decoder
+
+        # encoder outputs might need to be projected to different dimension for decoder
+        if (
+            self.encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            self.enc_to_dec_proj = keras.layers.Dense(
+                units=self.decoder.config.hidden_size,
+                kernel_initializer=get_initializer(config.encoder.initializer_range),
+                name="enc_to_dec_proj",
+            )
+
+        if self.encoder.get_output_embeddings() is not None:
+            raise ValueError(
+                f"The encoder {self.encoder} should not have an LM head. Please use a model without an LM head."
+            )
+
+        decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys())
+        if "encoder_hidden_states" not in decoder_signature:
+            raise ValueError(
+                "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+                "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+            )
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_input_embeddings(self):
+        return self.encoder.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.decoder.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        return self.decoder.set_output_embeddings(new_embeddings)
+
+    def tf_to_pt_weight_rename(self, tf_weight):
+        # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
+        # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
+        # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
+        # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's
+        # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name!
+
+        # This override is only needed in the case where we're crossloading weights from PT. However, since weights are
+        # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
+        # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
+        # or not.
+        encoder_model_type = self.config.encoder.model_type
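+        # e.g. for a BERT encoder this rewrites "encoder.bert.<...>" weight names to "encoder.<...>"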
+        if "encoder" in tf_weight and "decoder" not in tf_weight:
+            return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
+        else:
+            return (tf_weight,)
+
+    @classmethod
+    def from_encoder_decoder_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path: str | None = None,
+        decoder_pretrained_model_name_or_path: str | None = None,
+        *model_args,
+        **kwargs,
+    ) -> TFPreTrainedModel:
+        r"""
+        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+        checkpoints.
+
+
+        Params:
+            encoder_pretrained_model_name_or_path (`str`, *optional*):
+                Information necessary to initiate the encoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or URL to a *pytorch index checkpoint file* (e.g., `./pt_model/`). In this case,
+                      `encoder_from_pt` should be set to `True`.
+
+            decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+                Information necessary to initiate the decoder. Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or URL to a *pytorch checkpoint file* (e.g., `./pt_model/`). In this case,
+                      `decoder_from_pt` should be set to `True`.
+
+            model_args (remaining positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
+                `output_attentions=True`).
+
+                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+                - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+        Example:
+
+        ```python
+        >>> from transformers import TFEncoderDecoderModel
+
+        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+        >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2")
+        >>> # saving model after fine-tuning
+        >>> model.save_pretrained("./bert2gpt2")
+        >>> # load fine-tuned model
+        >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2")
+        ```"""
+
+        kwargs_encoder = {
+            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+        }
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # remove encoder, decoder kwargs from kwargs
+        for key in kwargs_encoder:
+            del kwargs["encoder_" + key]
+        for key in kwargs_decoder:
+            del kwargs["decoder_" + key]
+
+        # Load and initialize the encoder and decoder
+        # The distinction between encoder and decoder at the model level is made
+        # by the value of the flag `is_decoder` that we need to set correctly.
+        encoder = kwargs_encoder.pop("model", None)
+        if encoder is None:
+            if encoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_encoder:
+                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
+                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+                    logger.info(
+                        f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
+                        "from a decoder model. Cross-attention and causal mask are disabled."
+                    )
+                    encoder_config.is_decoder = False
+                    encoder_config.add_cross_attention = False
+
+                kwargs_encoder["config"] = encoder_config
+
+            kwargs_encoder["name"] = "encoder"
+            kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
+            encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+
+        decoder = kwargs_decoder.pop("model", None)
+        if decoder is None:
+            if decoder_pretrained_model_name_or_path is None:
+                raise ValueError(
+                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+                    "to be defined."
+                )
+
+            if "config" not in kwargs_decoder:
+                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+                    logger.info(
+                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+                        f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+                        f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+                    )
+                    decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True
+
+                kwargs_decoder["config"] = decoder_config
+
+            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+                logger.warning(
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                )
+
+            kwargs_decoder["name"] = "decoder"
+            kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
+            decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+        # Make sure these two `keras.Model` instances have fixed names so `from_pretrained` can load the model weights correctly.
+        if encoder.name != "encoder":
+            raise ValueError("encoder model must be created with the name `encoder`.")
+        if decoder.name != "decoder":
+            raise ValueError("decoder model must be created with the name `decoder`.")
+
+        # instantiate config with corresponding kwargs
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+        return cls(encoder=encoder, decoder=decoder, config=config)
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
+        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        encoder_outputs: np.ndarray | tf.Tensor | None = None,
+        past_key_values: tuple[tuple[tf.Tensor]] | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+        **kwargs,
+    ) -> TFSeq2SeqLMOutput | tuple[tf.Tensor]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import TFEncoderDecoderModel, BertTokenizer
+
+        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+        >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+        >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+        >>> # forward
+        >>> input_ids = tokenizer.encode(
+        ...     "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf"
+        ... )  # Batch size 1
+        >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
+
+        >>> # training
+        >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
+        >>> loss, logits = outputs.loss, outputs.logits
+
+        >>> # save and load from pretrained
+        >>> model.save_pretrained("bert2gpt2")
+        >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2")
+
+        >>> # generation
+        >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        # Let the user be responsible for the expected format.
+        if encoder_outputs is not None:
+            if return_dict and not isinstance(encoder_outputs, ModelOutput):
+                raise ValueError(
+                    "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
+                    f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
+                )
+
+        if encoder_outputs is None:
+            encoder_inputs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "inputs_embeds": inputs_embeds,
+                "output_attentions": output_attentions,
+                "output_hidden_states": output_hidden_states,
+                "return_dict": return_dict,
+                "training": training,
+            }
+
+            # Add arguments to encoder from `kwargs_encoder`
+            encoder_inputs.update(kwargs_encoder)
+
+            # Handle the case where the inputs are passed as a single dict which contains `labels`.
+            # The `labels` shouldn't be passed to `self.encoder` below, because it is a base model without this
+            # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`).
+            if "labels" in encoder_inputs:
+                labels = encoder_inputs.pop("labels")
+
+            # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
+            if "decoder_input_ids" in encoder_inputs:
+                decoder_input_ids = encoder_inputs.pop("decoder_input_ids")
+            # handle the init case where `dummy_inputs` returns a dict containing `decoder_attention_mask`.
+            if "decoder_attention_mask" in encoder_inputs:
+                decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask")
+
+            encoder_outputs = self.encoder(**encoder_inputs)
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        # optionally project encoder_hidden_states
+        if (
+            self.encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
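+        # if only `labels` are given, build the decoder inputs by right-shifting the labels (teacher forcing)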
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
+        decoder_inputs = {
+            "input_ids": decoder_input_ids,
+            "attention_mask": decoder_attention_mask,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_attention_mask": attention_mask,
+            "inputs_embeds": decoder_inputs_embeds,
+            "output_attentions": output_attentions,
+            "output_hidden_states": output_hidden_states,
+            "use_cache": use_cache,
+            "past_key_values": past_key_values,
+            "return_dict": return_dict,
+            "training": training,
+        }
+
+        # Add arguments to decoder from `kwargs_decoder`
+        decoder_inputs.update(kwargs_decoder)
+
+        decoder_outputs = self.decoder(**decoder_inputs)
+
+        logits = decoder_outputs[0]
+
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            warnings.warn(DEPRECATION_WARNING, FutureWarning)
+            loss = self.hf_compute_loss(labels, logits)
+
+        if not return_dict:
+            past_key_values = None
+            if use_cache:
+                past_key_values = decoder_outputs[1]
+            # The starting index of the remaining elements in `decoder_outputs`
+            start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+
+            if not isinstance(encoder_outputs, tuple):
+                encoder_outputs = encoder_outputs.to_tuple()
+            output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
+            output = tuple(x for x in output if x is not None)
+            return output
+
+        return TFSeq2SeqLMOutput(
+            loss=loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+    ):
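+        # delegate cache handling and input-id slicing to the wrapped decoder, then re-assemble the
+        # encoder-decoder level inputs expected by `call`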
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
+        decoder_attention_mask = decoder_inputs.get("attention_mask", None)
+        past_key_values = decoder_inputs.get("past_key_values")
+        if past_key_values is None:
+            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
+        input_dict = {
+            "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_input_ids": decoder_inputs["input_ids"],
+            # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
+            "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+        return input_dict
+
+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+    def resize_token_embeddings(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported. Please use the"
+            " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+            " model.decoder.resize_token_embeddings(...))"
+        )
+
+    def _reorder_cache(self, past, beam_idx):
+        # apply decoder cache reordering here
+        return self.decoder._reorder_cache(past, beam_idx)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "enc_to_dec_proj", None) is not None:
+            with tf.name_scope(self.enc_to_dec_proj.name):
+                self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "decoder", None) is not None:
+            with tf.name_scope(self.decoder.name):
+                self.decoder.build(None)
+
+
+__all__ = ["TFEncoderDecoderModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eac54d6ddcbdae2b8ca3771ae5540522f6f29da
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_esm import *
+    from .modeling_esm import *
+    from .modeling_esmfold import *
+    from .modeling_tf_esm import *
+    from .tokenization_esm import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/configuration_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/configuration_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabfb4ebd6d34a7f212af5e74a90c18d4a038156
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/configuration_esm.py
@@ -0,0 +1,365 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ESM model configuration"""
+
+from dataclasses import asdict, dataclass
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# TODO Update this
+
+
+class EsmConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`EsmModel`]. It is used to instantiate an ESM model
+    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the ESM
+    [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*):
+            Vocabulary size of the ESM model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`EsmModel`].
+        mask_token_id (`int`, *optional*):
+            The index of the mask token in the vocabulary. This must be included in the config because of the
+            "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
+        pad_token_id (`int`, *optional*):
+            The index of the padding token in the vocabulary. This must be included in the config because certain parts
+            of the ESM code use this instead of the attention mask.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 1026):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`, `"rotary"`.
+            For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        emb_layer_norm_before (`bool`, *optional*):
+            Whether to apply layer normalization after embeddings but before the main stem of the network.
+        token_dropout (`bool`, *optional*, defaults to `False`):
+            When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
+
+    Examples:
+
+    ```python
+    >>> from transformers import EsmModel, EsmConfig
+
+    >>> # Initializing an ESM facebook/esm-1b style configuration
+    >>> configuration = EsmConfig(vocab_size=33)
+
+    >>> # Initializing a model from the configuration
+    >>> model = EsmModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "esm"
+
+    def __init__(
+        self,
+        vocab_size=None,
+        mask_token_id=None,
+        pad_token_id=None,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=1026,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        position_embedding_type="absolute",
+        use_cache=True,
+        emb_layer_norm_before=None,
+        token_dropout=False,
+        is_folding_model=False,
+        esmfold_config=None,
+        vocab_list=None,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
+        self.emb_layer_norm_before = emb_layer_norm_before
+        self.token_dropout = token_dropout
+        self.is_folding_model = is_folding_model
+        if is_folding_model:
+            if esmfold_config is None:
+                logger.info("No esmfold_config supplied for folding model, using default values.")
+                esmfold_config = EsmFoldConfig()
+            elif isinstance(esmfold_config, dict):
+                esmfold_config = EsmFoldConfig(**esmfold_config)
+            self.esmfold_config = esmfold_config
+            if vocab_list is None:
+                logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
+                self.vocab_list = get_default_vocab_list()
+            else:
+                self.vocab_list = vocab_list
+        else:
+            self.esmfold_config = None
+            self.vocab_list = None
+        if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
+            raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = super().to_dict()
+        if isinstance(self.esmfold_config, EsmFoldConfig):
+            output["esmfold_config"] = self.esmfold_config.to_dict()
+        return output
+
+
+@dataclass
+class EsmFoldConfig:
+    esm_type: Optional[str] = None
+    fp16_esm: bool = True
+    use_esm_attn_map: bool = False
+    esm_ablate_pairwise: bool = False
+    esm_ablate_sequence: bool = False
+    esm_input_dropout: float = 0
+
+    embed_aa: bool = True
+    bypass_lm: bool = False
+
+    lddt_head_hid_dim: int = 128
+    trunk: "TrunkConfig" = None
+
+    def __post_init__(self):
+        if self.trunk is None:
+            self.trunk = TrunkConfig()
+        elif isinstance(self.trunk, dict):
+            self.trunk = TrunkConfig(**self.trunk)
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = asdict(self)
+        output["trunk"] = self.trunk.to_dict()
+        return output
+
+
+@dataclass
+class TrunkConfig:
+    num_blocks: int = 48
+    sequence_state_dim: int = 1024
+    pairwise_state_dim: int = 128
+    sequence_head_width: int = 32
+    pairwise_head_width: int = 32
+    position_bins: int = 32
+    dropout: float = 0
+    layer_drop: float = 0
+    cpu_grad_checkpoint: bool = False
+    max_recycles: int = 4
+    chunk_size: Optional[int] = 128
+    structure_module: "StructureModuleConfig" = None
+
+    def __post_init__(self):
+        if self.structure_module is None:
+            self.structure_module = StructureModuleConfig()
+        elif isinstance(self.structure_module, dict):
+            self.structure_module = StructureModuleConfig(**self.structure_module)
+
+        if self.max_recycles <= 0:
+            raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
+        if self.sequence_state_dim % self.sequence_head_width != 0:
+            raise ValueError(
+                "`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
+                f" {self.sequence_state_dim} and {self.sequence_head_width}."
+            )
+        if self.pairwise_state_dim % self.pairwise_head_width != 0:
+            raise ValueError(
+                "`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
+                f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
+            )
+
+        sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
+        pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
+
+        if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
+            raise ValueError(
+                "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
+                f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
+            )
+        if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
+            raise ValueError(
+                "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
+                f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
+            )
+        if self.pairwise_state_dim % 2 != 0:
+            raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")
+
+        if self.dropout >= 0.4:
+            raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = asdict(self)
+        output["structure_module"] = self.structure_module.to_dict()
+        return output
+
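+# Illustrative note (editor's addition, not in the upstream file): with the TrunkConfig defaults,
+# sequence_state_dim=1024 and sequence_head_width=32 give 1024 // 32 = 32 sequence attention heads,
+# while pairwise_state_dim=128 and pairwise_head_width=32 give 4 pairwise attention heads.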
+
+@dataclass
+class StructureModuleConfig:
+    """
+    Args:
+        sequence_dim:
+            Single representation channel dimension
+        pairwise_dim:
+            Pair representation channel dimension
+        ipa_dim:
+            IPA hidden channel dimension
+        resnet_dim:
+            Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
+        num_heads_ipa:
+            Number of IPA heads
+        num_qk_points:
+            Number of query/key points to generate during IPA
+        num_v_points:
+            Number of value points to generate during IPA
+        dropout_rate:
+            Dropout rate used throughout the layer
+        num_blocks:
+            Number of structure module blocks
+        num_transition_layers:
+            Number of layers in the single representation transition (Alg. 23 lines 8-9)
+        num_resnet_blocks:
+            Number of blocks in the angle resnet
+        num_angles:
+            Number of angles to generate in the angle resnet
+        trans_scale_factor:
+            Scale of single representation transition hidden dimension
+        epsilon:
+            Small number used in angle resnet normalization
+        inf:
+            Large number used for attention masking
+    """
+
+    sequence_dim: int = 384
+    pairwise_dim: int = 128
+    ipa_dim: int = 16
+    resnet_dim: int = 128
+    num_heads_ipa: int = 12
+    num_qk_points: int = 4
+    num_v_points: int = 8
+    dropout_rate: float = 0.1
+    num_blocks: int = 8
+    num_transition_layers: int = 1
+    num_resnet_blocks: int = 2
+    num_angles: int = 7
+    trans_scale_factor: int = 10
+    epsilon: float = 1e-8
+    inf: float = 1e5
+
+    def to_dict(self):
+        return asdict(self)
+
+
+def get_default_vocab_list():
+    return (
+        "",
+        "",
+        "",
+        "",
+        "L",
+        "A",
+        "G",
+        "V",
+        "S",
+        "E",
+        "R",
+        "T",
+        "I",
+        "D",
+        "P",
+        "K",
+        "Q",
+        "N",
+        "F",
+        "Y",
+        "M",
+        "H",
+        "W",
+        "C",
+        "X",
+        "B",
+        "U",
+        "Z",
+        "O",
+        ".",
+        "-",
+        "",
+        "",
+    )
+
+
+__all__ = ["EsmConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d9344188cc83bd8ac4719db78def849afbda7f
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py
@@ -0,0 +1,1058 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ESM model."""
+
+import math
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import OutputRecorder, check_model_inputs
+from .configuration_esm import EsmConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(x, cos, sin):
+    cos = cos[:, :, : x.shape[-2], :]
+    sin = sin[:, :, : x.shape[-2], :]
+
+    return (x * cos) + (rotate_half(x) * sin)
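+
+
+# Illustrative note (editor's addition, not in the upstream file): for a head dimension of 4,
+# rotate_half maps [x1, x2, x3, x4] to [-x3, -x4, x1, x2], so (x * cos) + (rotate_half(x) * sin)
+# applies the pairwise 2D rotations used by rotary position embeddings to each query/key vector.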
+
+
+def gelu(x):
+    """
+    This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
+    """
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def symmetrize(x):
+    "Make layer symmetric in final two dimensions, used for contact prediction."
+    return x + x.transpose(-1, -2)
+
+
+def average_product_correct(x):
+    "Perform average product correct, used for contact prediction."
+    a1 = x.sum(-1, keepdims=True)
+    a2 = x.sum(-2, keepdims=True)
+    a12 = x.sum((-1, -2), keepdims=True)
+
+    avg = a1 * a2
+    avg.div_(a12)  # in-place to reduce memory
+    normalized = x - avg
+    return normalized
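+
+
+# Illustrative note (editor's addition, not in the upstream file): average product correction
+# computes x_ij - (row_sum_i * col_sum_j) / total_sum over the last two dimensions, a standard
+# denoising step applied to symmetrized attention maps before contact prediction.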
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """
+    Rotary position embeddings based on those in
+    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
+    matrices which depend on their relative positions.
+    """
+
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, dim: int):
+        super().__init__()
+        # Generate and save the inverse frequency buffer (non trainable)
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        self._seq_len_cached = None
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_tables(self, x, seq_dimension=2):
+        seq_len = x.shape[seq_dimension]
+
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+            self._seq_len_cached = seq_len
+            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+            self._cos_cached = emb.cos()[None, None, :, :]
+            self._sin_cached = emb.sin()[None, None, :, :]
+
+        return self._cos_cached, self._sin_cached
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
+
+        return (
+            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
+            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
+        )
+
+
+class EsmContactPredictionHead(nn.Module):
+    """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+    def __init__(
+        self,
+        in_features: int,
+        bias=True,
+        eos_idx: int = 2,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.eos_idx = eos_idx
+        self.regression = nn.Linear(in_features, 1, bias)
+        self.activation = nn.Sigmoid()
+
+    def forward(self, tokens, attentions):
+        # remove eos token attentions
+        eos_mask = tokens.ne(self.eos_idx).to(attentions)
+        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
+        attentions = attentions * eos_mask[:, None, None, :, :]
+        attentions = attentions[..., :-1, :-1]
+        # remove cls token attentions
+        attentions = attentions[..., 1:, 1:]
+        batch_size, layers, heads, seqlen, _ = attentions.size()
+        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
+
+        # features: batch x channels x tokens x tokens (symmetric)
+        attentions = attentions.to(
+            self.regression.weight.device
+        )  # attentions always float32, may need to convert to float16
+        attentions = average_product_correct(symmetrize(attentions))
+        attentions = attentions.permute(0, 2, 3, 1)
+        return self.activation(self.regression(attentions).squeeze(3))
+
+
+class EsmEmbeddings(nn.Module):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+
+        if config.emb_layer_norm_before:
+            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        else:
+            self.layer_norm = None
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+
+        self.padding_idx = config.pad_token_id
+        if self.position_embedding_type == "absolute":
+            self.position_embeddings = nn.Embedding(
+                config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+            )
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        inputs_embeds=None,
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+        # embedding_scale factor here.
+        embeddings = inputs_embeds
+
+        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+        # masked tokens are treated as if they were selected for input dropout and zeroed out.
+        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
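+        # Illustrative example (editor's addition, not in the upstream file): with
+        # mask_ratio_train = 0.15 * 0.8 = 0.12, a sample in which 25% of tokens are masked has its
+        # embeddings scaled by (1 - 0.12) / (1 - 0.25) ≈ 1.17 to compensate for the zeroed-out
+        # mask-token embeddings.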
+        if self.token_dropout and input_ids is not None:
+            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+            src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
+            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
+            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
+                embeddings.dtype
+            )
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = embeddings + position_embeddings
+
+        if self.layer_norm is not None:
+            embeddings = self.layer_norm(embeddings)
+        if attention_mask is not None:
+            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
+        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
+        # embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    # ESM applies relative position embeddings and we don't copy from Llama
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+
+    if hasattr(module, "position_embedding_type") and module.position_embedding_type in [
+        "relative_key",
+        "relative_key_query",
+    ]:
+        seq_length = query.shape[2]
+        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1)
+        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1)
+        distance = position_ids_l - position_ids_r
+        positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
+        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility
+
+        if module.position_embedding_type == "relative_key":
+            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+        elif module.position_embedding_type == "relative_key_query":
+            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
+            relative_position_scores = relative_position_scores_query + relative_position_scores_key
+
+        attn_weights = attn_weights + relative_position_scores
+
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
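+
+
+# Illustrative note (editor's addition, not in the upstream file): for seq_length = 3, the distance
+# matrix position_ids_l - position_ids_r is [[0, -1, -2], [1, 0, -1], [2, 1, 0]]; adding
+# max_position_embeddings - 1 shifts it into the non-negative index range expected by
+# distance_embedding.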
+
+
+class EsmSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
+        super().__init__()
+        self.config = config
+
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = config.attention_probs_dropout_prob
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        self.rotary_embeddings = None
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+        elif self.position_embedding_type == "rotary":
+            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
+
+        self.scaling = 1.0  # For BC we apply scaling before RoPE
+        self.is_decoder = config.is_decoder
+        self.layer_idx = layer_idx
+        self.is_causal = self.is_decoder and not is_cross_attention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor]:
+        batch_size, seq_length = hidden_states.shape[:-1]
+        hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
+
+        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        is_cross_attention = encoder_hidden_states is not None
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+        key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
+        value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
+
+        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+        # ESM code and fix rotary embeddings.
+        query_layer = query_layer * self.attention_head_size**-0.5
+
+        if self.position_embedding_type == "rotary":
+            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.position_embedding_type in ["relative_key", "relative_key_query"]:
+                raise ValueError(
+                    f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. "
+                    "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`"
+                )
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            head_mask=head_mask,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
+        return attn_output, attn_weights
+
+
+class EsmSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states + input_tensor
+        return hidden_states
+
+
+class EsmAttention(nn.Module):
+    def __init__(self, config, layer_idx=None, is_cross_attention=False):
+        super().__init__()
+        self.self = EsmSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
+        self.output = EsmSelfOutput(config)
+        self.pruned_heads = set()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        hidden_states_ln = self.LayerNorm(hidden_states)
+        attn_output, _ = self.self(
+            hidden_states_ln,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
+        )
+        attn_output = self.output(attn_output, hidden_states)
+        return attn_output
+
+
+class EsmIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = gelu(hidden_states)
+        return hidden_states
+
+
+class EsmOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states + input_tensor
+        return hidden_states
+
+
+class EsmLayer(GradientCheckpointingLayer):
+    def __init__(self, config):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = EsmAttention(config)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = EsmAttention(config, is_cross_attention=True)
+        self.intermediate = EsmIntermediate(config)
+        self.output = EsmOutput(config)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        attention_output = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            **kwargs,
+        )
+
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise AttributeError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
+                    " with cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+
+            attention_output = self.crossattention(
+                attention_output,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                **kwargs,
+            )
+
+        layer_output = self.feed_forward_chunk(attention_output)
+        return layer_output
+
+    def feed_forward_chunk(self, attention_output):
+        attention_output_ln = self.LayerNorm(attention_output)
+        intermediate_output = self.intermediate(attention_output_ln)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class EsmEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
+        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.gradient_checkpointing = False
+
+    @can_return_tuple
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        for i, layer_module in enumerate(self.layer):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            hidden_states = layer_module(
+                hidden_states,
+                attention_mask=attention_mask,
+                head_mask=layer_head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                **kwargs,
+            )
+
+        if self.emb_layer_norm_after:
+            hidden_states = self.emb_layer_norm_after(hidden_states)
+
+        return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class EsmPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+@auto_docstring
+class EsmPreTrainedModel(PreTrainedModel):
+    config: EsmConfig
+    base_model_prefix = "esm"
+    supports_gradient_checkpointing = True
+    accepts_loss_kwargs = False
+    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
+    _keys_to_ignore_on_load_unexpected = ["position_embeddings.weight"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_attention_backend = True
+
+    _can_record_outputs = {
+        "hidden_states": EsmLayer,
+        "attentions": [OutputRecorder(EsmSelfAttention, index=1, layer_name="attention")],
+        "cross_attentions": [
+            OutputRecorder(EsmSelfAttention, index=1, layer_name="crossattention"),
+        ],
+    }
+
+    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, EsmLMHead):
+            module.bias.data.zero_()
+
+    def get_output_embeddings(self):
+        # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
+        # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
+        return None
+
+
+@auto_docstring
+class EsmModel(EsmPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        r"""
+        add_pooling_layer (bool, *optional*, defaults to `True`):
+            Whether to add a pooling layer
+        """
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = EsmEmbeddings(config)
+        self.encoder = EsmEncoder(config)
+
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+
+        self.contact_head = EsmContactPredictionHead(
+            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(
+                input_ids=input_ids,
+                position_ids=position_ids,
+            )
+
+        if self.config._attn_implementation != "flash_attention_2":
+            batch_size, seq_length = inputs_embeds.shape[:-1]
+            if attention_mask is None:
+                attention_mask = torch.ones((batch_size, seq_length), device=inputs_embeds.device)
+
+            attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                attention_mask, input_shape=(batch_size, seq_length)
+            )
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            **kwargs,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+        )
+
+    def predict_contacts(self, tokens, attention_mask):
+        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
+        attns = torch.stack(attns, dim=1)  # Matches the original model layout
+        # In the original model, attentions for padding tokens are completely zeroed out.
+        # This makes no difference most of the time because the other tokens won't attend to them,
+        # but it does for the contact prediction task, which takes attentions as input,
+        # so we have to mimic that here.
+        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
+        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
+        return self.contact_head(tokens, attns)
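+
+    # Illustrative usage sketch (editor's addition, not in the upstream file; the checkpoint name is
+    # only an example):
+    #     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    #     model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    #     inputs = tokenizer("MKTAYIAKQR", return_tensors="pt")
+    #     contacts = model.predict_contacts(inputs["input_ids"], inputs["attention_mask"])
+    #     # contacts has shape (batch, seq_len - 2, seq_len - 2): per-pair contact probabilities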
+
+
+@auto_docstring
+class EsmForMaskedLM(EsmPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+            )
+
+        self.esm = EsmModel(config, add_pooling_layer=False)
+        self.lm_head = EsmLMHead(config)
+
+        self.init_weights()
+
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        """
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            **kwargs,
+        )
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(prediction_scores.device)
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def predict_contacts(self, tokens, attention_mask):
+        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
+
+
+class EsmLMHead(nn.Module):
+    """ESM Head for masked language modeling."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+    def forward(self, features, **kwargs):
+        x = self.dense(features)
+        x = gelu(x)
+        x = self.layer_norm(x)
+
+        # project back to size of vocabulary with bias
+        x = self.decoder(x) + self.bias
+        return x
+
+
+@auto_docstring(
+    custom_intro="""
+    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
+    """
+)
+class EsmForSequenceClassification(EsmPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.esm = EsmModel(config, add_pooling_layer=False)
+        self.classifier = EsmClassificationHead(config)
+
+        self.init_weights()
+
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@auto_docstring
+class EsmForTokenClassification(EsmPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.esm = EsmModel(config, add_pooling_layer=False)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> Union[tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class EsmClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: torch.Tensor
+        padding_idx: int
+
+    Returns: torch.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
+    return incremental_indices.long() + padding_idx
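+
+
+# Illustrative example (editor's addition, not in the upstream file): for input_ids = [[5, 6, 1, 1]]
+# with padding_idx = 1, the mask is [[1, 1, 0, 0]], the masked cumulative sum is [[1, 2, 0, 0]], and
+# the returned position ids are [[2, 3, 1, 1]], so padding positions keep the padding index.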
+
+
+__all__ = [
+    "EsmForMaskedLM",
+    "EsmForSequenceClassification",
+    "EsmForTokenClassification",
+    "EsmModel",
+    "EsmPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esmfold.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esmfold.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc1f0dbdc701991a4109ddbc617eb1b3769c6a1
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_esmfold.py
@@ -0,0 +1,2309 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import sys
+from collections.abc import Sequence
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+
+from ...integrations.deepspeed import is_deepspeed_available
+from ...modeling_outputs import ModelOutput
+from ...utils import (
+    ContextManagers,
+    auto_docstring,
+    is_scipy_available,
+    logging,
+)
+from .modeling_esm import EsmModel, EsmPreTrainedModel
+from .openfold_utils import (
+    OFProtein,
+    Rigid,
+    Rotation,
+    atom14_to_atom37,
+    chunk_layer,
+    compute_predicted_aligned_error,
+    compute_tm,
+    frames_and_literature_positions_to_atom14_pos,
+    make_atom14_masks,
+    residue_constants,
+    to_pdb,
+    torsion_angles_to_frames,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Output type of [`EsmForProteinFolding`].
+    """
+)
+class EsmForProteinFoldingOutput(ModelOutput):
+    r"""
+    frames (`torch.FloatTensor`):
+        Output frames.
+    sidechain_frames (`torch.FloatTensor`):
+        Output sidechain frames.
+    unnormalized_angles (`torch.FloatTensor`):
+        Predicted unnormalized backbone and side chain torsion angles.
+    angles (`torch.FloatTensor`):
+        Predicted backbone and side chain torsion angles.
+    positions (`torch.FloatTensor`):
+        Predicted positions of the backbone and side chain atoms.
+    states (`torch.FloatTensor`):
+        Hidden states from the protein folding trunk.
+    s_s (`torch.FloatTensor`):
+        Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
+    s_z (`torch.FloatTensor`):
+        Pairwise residue embeddings.
+    distogram_logits (`torch.FloatTensor`):
+        Input logits to the distogram used to compute residue distances.
+    lm_logits (`torch.FloatTensor`):
+        Logits output by the ESM-2 protein language model stem.
+    aatype (`torch.FloatTensor`):
+        Input amino acids (AlphaFold2 indices).
+    atom14_atom_exists (`torch.FloatTensor`):
+        Whether each atom exists in the atom14 representation.
+    residx_atom14_to_atom37 (`torch.FloatTensor`):
+        Mapping between atoms in the atom14 and atom37 representations.
+    residx_atom37_to_atom14 (`torch.FloatTensor`):
+        Mapping between atoms in the atom37 and atom14 representations.
+    atom37_atom_exists (`torch.FloatTensor`):
+        Whether each atom exists in the atom37 representation.
+    residue_index (`torch.FloatTensor`):
+        The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
+        a sequence of integers from 0 to `sequence_length`.
+    lddt_head (`torch.FloatTensor`):
+        Raw outputs from the lddt head used to compute plddt.
+    plddt (`torch.FloatTensor`):
+        Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
+        uncertain, or where the protein structure is disordered.
+    ptm_logits (`torch.FloatTensor`):
+        Raw logits used for computing ptm.
+    ptm (`torch.FloatTensor`):
+        TM-score output representing the model's high-level confidence in the overall structure.
+    aligned_confidence_probs (`torch.FloatTensor`):
+        Per-residue confidence scores for the aligned structure.
+    predicted_aligned_error (`torch.FloatTensor`):
+        Predicted error between the model's prediction and the ground truth.
+    max_predicted_aligned_error (`torch.FloatTensor`):
+        Per-sample maximum predicted error.
+    """
+
+    frames: Optional[torch.FloatTensor] = None
+    sidechain_frames: Optional[torch.FloatTensor] = None
+    unnormalized_angles: Optional[torch.FloatTensor] = None
+    angles: Optional[torch.FloatTensor] = None
+    positions: Optional[torch.FloatTensor] = None
+    states: Optional[torch.FloatTensor] = None
+    s_s: Optional[torch.FloatTensor] = None
+    s_z: Optional[torch.FloatTensor] = None
+    distogram_logits: Optional[torch.FloatTensor] = None
+    lm_logits: Optional[torch.FloatTensor] = None
+    aatype: Optional[torch.FloatTensor] = None
+    atom14_atom_exists: Optional[torch.FloatTensor] = None
+    residx_atom14_to_atom37: Optional[torch.FloatTensor] = None
+    residx_atom37_to_atom14: Optional[torch.FloatTensor] = None
+    atom37_atom_exists: Optional[torch.FloatTensor] = None
+    residue_index: Optional[torch.FloatTensor] = None
+    lddt_head: Optional[torch.FloatTensor] = None
+    plddt: Optional[torch.FloatTensor] = None
+    ptm_logits: Optional[torch.FloatTensor] = None
+    ptm: Optional[torch.FloatTensor] = None
+    aligned_confidence_probs: Optional[torch.FloatTensor] = None
+    predicted_aligned_error: Optional[torch.FloatTensor] = None
+    max_predicted_aligned_error: Optional[torch.FloatTensor] = None
+
+
+def is_fp16_enabled(device_type):
+    # Autocast world
+    autocast_dtype = (
+        torch.get_autocast_dtype(device_type)
+        if hasattr(torch, "get_autocast_dtype")
+        else torch.get_autocast_gpu_dtype()
+    )
+    fp16_enabled = autocast_dtype == torch.float16
+    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
+
+    return fp16_enabled
+
+
+def is_deepspeed_initialized():
+    if is_deepspeed_available():
+        return False
+    else:
+        try:
+            import deepspeed
+
+            # This is not available in all DeepSpeed versions.
+            return deepspeed.utils.is_initialized()
+        except Exception:
+            return False
+
+
+def collate_dense_tensors(samples: list[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
+    """
+    Takes a list of tensors with the following dimensions:
+        [(d_11, ..., d_1K),
+         (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
+    and stacks + pads them into a single tensor of shape:
+        (N, max_i d_i1, ..., max_i d_iK)
+    """
+    if len(samples) == 0:
+        return torch.Tensor()
+    if len({x.dim() for x in samples}) != 1:
+        raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
+    (device,) = tuple({x.device for x in samples})  # assumes all on same device
+    max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
+    result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
+    result.fill_(pad_v)
+    for i in range(len(samples)):
+        result_i = result[i]
+        t = samples[i]
+        result_i[tuple(slice(0, k) for k in t.shape)] = t
+    return result
+
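+# Example (illustrative): collate_dense_tensors([torch.ones(2, 3), torch.ones(1, 4)])
+# returns a tensor of shape (2, 2, 4); each sample is copied into the top-left corner of
+# its slice and the remaining entries are filled with pad_v.
+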
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+    return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: list[int]):
+    zero_index = -1 * len(inds)
+    first_inds = list(range(len(tensor.shape[:zero_index])))
+    return tensor.permute(first_inds + [zero_index + i for i in inds])
+
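+# Example (illustrative): for a tensor t of shape (B, I, J, C),
+#     permute_final_dims(t, (2, 0, 1))  ->  shape (B, C, I, J)
+#     flatten_final_dims(t, 2)          ->  shape (B, I, J * C)
+# Only the trailing dimensions are rearranged; leading (batch-like) dimensions are kept.
+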
+
+def dict_multimap(fn, dicts):
+    first = dicts[0]
+    new_dict = {}
+    for k, v in first.items():
+        all_v = [d[k] for d in dicts]
+        if isinstance(v, dict):
+            new_dict[k] = dict_multimap(fn, all_v)
+        else:
+            new_dict[k] = fn(all_v)
+
+    return new_dict
+
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+    shape = weights.shape
+    scale = scale / max(1, shape[1])
+
+    if not is_scipy_available():
+        logger.warning(
+            "This init requires scipy, but scipy was not found; falling back to an approximation that may not be"
+            " exactly equivalent."
+        )
+        std = math.sqrt(scale)
+        # Approximate truncnorm(a=-2, b=2) by sampling a normal and clamping in-place to +/- 2 std.
+        torch.nn.init.normal_(weights, std=std)
+        weights.clamp_(min=-2.0 * std, max=2.0 * std)
+
+    else:
+        from scipy.stats import truncnorm
+
+        std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
+        samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
+        samples = np.reshape(samples, shape)
+        weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def ipa_point_weights_init_(weights):
+    with torch.no_grad():
+        softplus_inverse_1 = 0.541324854612918
+        weights.fill_(softplus_inverse_1)
+
+
+class EsmFoldLinear(nn.Linear):
+    """
+    A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.
+
+    Implements the initializers in Section 1.11.4 of the AlphaFold 2 supplement, plus some additional ones found in
+    the code.
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        bias: bool = True,
+        init: str = "default",
+        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+    ):
+        """
+        Args:
+            in_dim:
+                The final dimension of inputs to the layer
+            out_dim:
+                The final dimension of layer outputs
+            bias:
+                Whether to learn an additive bias. True by default
+            init:
+                The initializer to use. Choose from:
+
+                "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal
+                distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal":
+                Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0
+
+                Overridden by init_fn if the latter is not None.
+            init_fn:
+                A custom initializer taking weight and bias as inputs. Overrides init if not None.
+        """
+        super().__init__(in_dim, out_dim, bias=bias)
+
+        if bias:
+            with torch.no_grad():
+                self.bias.fill_(0)
+        self.init = init
+        self.init_fn = init_fn
+
+        if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
+            raise ValueError("Invalid init string.")
+
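+# Usage sketch (illustrative; `c_in`/`c_out` are placeholder sizes): the layer behaves
+# like a plain nn.Linear, and the `init` string only records which scheme
+# EsmFoldPreTrainedModel._init_weights applies later, e.g.
+#     gate_proj = EsmFoldLinear(c_in, c_out, init="gating")  # weight -> 0, bias -> 1
+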
+
+class EsmFoldLayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super().__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if d is torch.bfloat16 and not is_deepspeed_initialized():
+            with torch.autocast(device_type="cuda", enabled=False):
+                out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
+        else:
+            out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
+
+        return out
+
+
+@torch.jit.ignore
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """
+    Softmax, but without automatic casting to fp32 when the input is of type bfloat16
+    """
+    d = t.dtype
+    if d is torch.bfloat16 and not is_deepspeed_initialized():
+        with torch.autocast(device_type="cuda", enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    else:
+        s = torch.nn.functional.softmax(t, dim=dim)
+
+    return s
+
+
+class EsmFoldAttention(nn.Module):
+    """
+    Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
+    """
+
+    def __init__(
+        self,
+        c_q: int,
+        c_k: int,
+        c_v: int,
+        c_hidden: int,
+        no_heads: int,
+        gating: bool = True,
+    ):
+        """
+        Args:
+            c_q:
+                Input dimension of query data
+            c_k:
+                Input dimension of key data
+            c_v:
+                Input dimension of value data
+            c_hidden:
+                Per-head hidden dimension
+            no_heads:
+                Number of attention heads
+            gating:
+                Whether the output should be gated using query data
+        """
+        super().__init__()
+
+        self.c_q = c_q
+        self.c_k = c_k
+        self.c_v = c_v
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.gating = gating
+
+        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
+        # stated in the supplement, but the overall channel dimension.
+
+        self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
+        self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
+        self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
+        self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")
+
+        self.linear_g = None
+        if self.gating:
+            self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")
+
+        self.sigmoid = nn.Sigmoid()
+
+    def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # [*, Q/K/V, H * C_hidden]
+        q = self.linear_q(q_x)
+        k = self.linear_k(kv_x)
+        v = self.linear_v(kv_x)
+
+        # [*, Q/K, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        # [*, H, Q/K, C_hidden]
+        q = q.transpose(-2, -3)
+        k = k.transpose(-2, -3)
+        v = v.transpose(-2, -3)
+
+        q /= math.sqrt(self.c_hidden)
+
+        return q, k, v
+
+    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
+        if self.linear_g is not None:
+            g = self.sigmoid(self.linear_g(q_x))
+
+            # [*, Q, H, C_hidden]
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
+            o = o * g
+
+        # [*, Q, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, Q, C_q]
+        o = self.linear_o(o)
+
+        return o
+
+    def forward(
+        self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor,
+        biases: Optional[list[torch.Tensor]] = None,
+        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        lma_q_chunk_size: int = 1024,
+        lma_kv_chunk_size: int = 4096,
+        use_flash: bool = False,
+        flash_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            q_x:
+                [*, Q, C_q] query data
+            kv_x:
+                [*, K, C_k] key data
+            biases:
+                List of biases that broadcast to [*, H, Q, K]
+            use_memory_efficient_kernel:
+                Whether to use a custom memory-efficient attention kernel. This should be the default choice for most
+                cases. If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
+            use_lma:
+                Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
+                stock PyTorch implementation is used instead
+            lma_q_chunk_size:
+                Query chunk size (for LMA)
+            lma_kv_chunk_size:
+                Key/Value chunk size (for LMA)
+        Returns:
+            [*, Q, C_q] attention update
+        """
+        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
+            raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")
+
+        if use_flash and biases is not None:
+            raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")
+
+        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
+        if sum(attn_options) > 1:
+            raise ValueError("Choose at most one alternative attention algorithm")
+
+        if biases is None:
+            biases = []
+
+        # [*, H, Q/K, C_hidden]
+        query, key, value = self._prep_qkv(q_x, kv_x)
+        key = permute_final_dims(key, (1, 0))
+
+        # [*, H, Q, K]
+        output = torch.matmul(query, key)
+        for b in biases:
+            output += b
+        output = softmax_no_cast(output, -1)
+
+        # [*, H, Q, C_hidden]
+        output = torch.matmul(output, value)
+        output = output.transpose(-2, -3)
+        output = self._wrap_up(output, q_x)
+
+        return output
+
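+# Shape walk-through (illustrative): for q_x of shape (B, Q, C_q), kv_x of shape
+# (B, K, C_k), no_heads = H and c_hidden = C:
+#     q, k, v                               -> (B, H, Q or K, C)
+#     q @ k^T + biases                      -> (B, H, Q, K)
+#     softmax @ v, gate, flatten, linear_o  -> (B, Q, C_q)
+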
+
+class EsmFoldTriangleAttention(nn.Module):
+    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
+        """
+        Args:
+            c_in:
+                Input channel dimension
+            c_hidden:
+                Overall hidden channel dimension (not per-head)
+            no_heads:
+                Number of attention heads
+        """
+        super().__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.starting = starting
+        self.inf = inf
+
+        self.layer_norm = LayerNorm(self.c_in)
+
+        self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")
+
+        self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)
+
+    @torch.jit.ignore
+    def _chunk(
+        self,
+        x: torch.Tensor,
+        biases: list[torch.Tensor],
+        chunk_size: int,
+        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        "triangle! triangle!"
+        mha_inputs = {
+            "q_x": x,
+            "kv_x": x,
+            "biases": biases,
+        }
+
+        return chunk_layer(
+            partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
+            mha_inputs,
+            chunk_size=chunk_size,
+            no_batch_dims=len(x.shape[:-2]),
+            _out=x if inplace_safe else None,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
+        use_lma: bool = False,
+        inplace_safe: bool = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x:
+                [*, I, J, C_in] input tensor (e.g. the pair representation)
+        Returns:
+            [*, I, J, C_in] output tensor
+        """
+        if mask is None:
+            # [*, I, J]
+            mask = x.new_ones(
+                x.shape[:-1],
+            )
+
+        if not self.starting:
+            x = x.transpose(-2, -3)
+            mask = mask.transpose(-1, -2)
+
+        # [*, I, J, C_in]
+        x = self.layer_norm(x)
+
+        # [*, I, 1, 1, J]
+        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
+
+        # [*, H, I, J]
+        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
+
+        # [*, 1, H, I, J]
+        triangle_bias = triangle_bias.unsqueeze(-4)
+
+        biases = [mask_bias, triangle_bias]
+
+        if chunk_size is not None:
+            x = self._chunk(
+                x,
+                biases,
+                chunk_size,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+                inplace_safe=inplace_safe,
+            )
+        else:
+            x = self.mha(
+                q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
+            )
+
+        if not self.starting:
+            x = x.transpose(-2, -3)
+
+        return x
+
+
+class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
+    """
+    Implements Algorithms 11 and 12.
+    """
+
+    def __init__(self, config, _outgoing=True):
+        super().__init__()
+        c_hidden = config.pairwise_state_dim
+        self._outgoing = _outgoing
+
+        self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
+        self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
+        self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
+        self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
+        self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
+        self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")
+
+        self.layer_norm_in = LayerNorm(c_hidden)
+        self.layer_norm_out = LayerNorm(c_hidden)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def _combine_projections(
+        self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
+    ) -> torch.Tensor:
+        if self._outgoing:
+            a = permute_final_dims(a, (2, 0, 1))
+            b = permute_final_dims(b, (2, 1, 0))
+        else:
+            a = permute_final_dims(a, (2, 1, 0))
+            b = permute_final_dims(b, (2, 0, 1))
+
+        if _inplace_chunk_size is not None:
+            # To be replaced by torch vmap
+            for i in range(0, a.shape[-3], _inplace_chunk_size):
+                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
+                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
+                a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
+                    a_chunk,
+                    b_chunk,
+                )
+
+            p = a
+        else:
+            p = torch.matmul(a, b)
+
+        return permute_final_dims(p, (1, 2, 0))
+
+    def _inference_forward(
+        self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        inplace_chunk_size: Optional[int] = None,
+        with_add: bool = True,
+    ):
+        """
+        Args:
+            z:
+                A [*, N, N, C_z] pair representation
+            mask:
+                A [*, N, N] pair mask
+            inplace_chunk_size:
+                Size of chunks used in the main computation. Increase to trade memory for speed.
+            with_add:
+                If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
+        Returns:
+            A reference to the overwritten z
+
+        More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
+        addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
+        values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
+        Useful for inference on extremely long sequences.
+
+        It works as follows. We will make reference to variables used in the default forward implementation below.
+        Naively, the triangle multiplicative update requires the manifestation of 5 tensors the size of z: 1) z, the
+        "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of z, 4) g, a z-sized mask,
+        and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
+        N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
+        tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
+        tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
+        pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
+        inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
+        total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
+        directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
+        the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
+        ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
+        however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
+        a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
+        0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
+        iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
+        Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
+        z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
+        After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
+        If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
+        peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
+        variables.
+        """
+        if mask is None:
+            mask = z.new_ones(z.shape[:-1])
+
+        mask = mask.unsqueeze(-1)
+
+        def compute_projection_helper(pair, mask, a=True):
+            if a:
+                linear_g = self.linear_a_g
+                linear_p = self.linear_a_p
+            else:
+                linear_g = self.linear_b_g
+                linear_p = self.linear_b_p
+
+            pair = self.layer_norm_in(pair)
+            p = linear_g(pair)
+            p.sigmoid_()
+            p *= linear_p(pair)
+            p *= mask
+            p = permute_final_dims(p, (2, 0, 1))
+            return p
+
+        def compute_projection(pair, mask, a=True, chunked=True):
+            need_transpose = self._outgoing ^ a
+            if not chunked:
+                p = compute_projection_helper(pair, mask, a)
+                if need_transpose:
+                    p = p.transpose(-1, -2)
+            else:
+                # This computation is chunked so as not to exceed our 2.5x
+                # budget with a large intermediate tensor
+                linear_g = self.linear_a_g if a else self.linear_b_g
+                c = linear_g.bias.shape[-1]
+                out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
+                p = pair.new_zeros(out_shape)
+                for i in range(0, pair.shape[-3], inplace_chunk_size):
+                    pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
+                    pair_chunk = compute_projection_helper(
+                        pair[..., i : i + inplace_chunk_size, :, :],
+                        mask[..., i : i + inplace_chunk_size, :, :],
+                        a,
+                    )
+                    if need_transpose:
+                        pair_chunk = pair_chunk.transpose(-1, -2)
+                        p[..., i : i + inplace_chunk_size] = pair_chunk
+                    else:
+                        p[..., i : i + inplace_chunk_size, :] = pair_chunk
+
+                    del pair_chunk
+
+            return p
+
+        # We start by fully manifesting a. In addition to the input, this
+        # brings total memory consumption to 2x z (disregarding size of chunks)
+        # [*, N, N, c]
+        a = compute_projection(z, mask, True, chunked=True)
+
+        if inplace_chunk_size is not None:
+            n = a.shape[-1]
+            half_n = n // 2 + n % 2
+            row_dim = -3
+            col_dim = -2
+            b_chunk_dim = row_dim if self._outgoing else col_dim
+
+            def empty_slicer(t):
+                return [slice(None) for _ in t.shape]
+
+            def slice_tensor(t, start, end, dim):
+                # Slices start:end from the dim dimension of t
+                s = empty_slicer(t)
+                s[dim] = slice(start, end)
+                return t[s]
+
+            def flip_z_cache_(z_cache, z):
+                # "Reorient" the z_cache (see below), filling it with quadrants
+                # 3---recovered from the z_cache---and 4---recovered from z---
+                # of the input tensor z.
+                quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
+                z_cache = z_cache.transpose(row_dim, col_dim)
+
+                # If n is odd, we need to shrink the z_cache by one row
+                z_cache = z_cache[..., : (n // 2), :, :]
+
+                # Move the 3rd quadrant of z into the first half of the rotated z-cache
+                first_half_slicer = empty_slicer(z_cache)
+                first_half_slicer[col_dim] = slice(0, half_n)
+                z_cache[first_half_slicer] = quadrant_3
+
+                # Get the fourth quadrant of z
+                quadrant_4 = slice_tensor(z, half_n, None, row_dim)
+                quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)
+
+                # Insert said quadrant into the rotated z-cache
+                quadrant_3_slicer = empty_slicer(z_cache)
+                quadrant_3_slicer[col_dim] = slice(half_n, None)
+
+                z_cache[quadrant_3_slicer] = quadrant_4
+
+                return z_cache
+
+            # Initialize the z cache to the left half of z.
+            z_cache_shape = list(z.shape)
+            z_cache_shape[col_dim] = half_n
+            z_cache = z.new_zeros(z_cache_shape)
+            z_cache_slicer = empty_slicer(z_cache)
+            z_cache_slicer[col_dim] = slice(0, half_n)
+            z_cache.copy_(z[z_cache_slicer])
+            z_cache_rotated = False
+
+            # We need to reorient the z-cache at the halfway point, and we
+            # don't want a single chunk to straddle that point. We contract one
+            # of the chunks in the middle to address that problem.
+            i_range = list(range(0, half_n, inplace_chunk_size))
+            initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
+            after_half = list(range(half_n, n, inplace_chunk_size))
+            after_half_offsets = [inplace_chunk_size for _ in after_half]
+            combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
+            for i, offset in combined_range_with_offsets:
+                if not z_cache_rotated and i >= half_n:
+                    z_cache = flip_z_cache_(z_cache, z)
+                    z_cache_rotated = True
+
+                z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
+                mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)
+
+                z_chunk_b = z_chunk_b.clone()
+                if b_chunk_dim == col_dim:
+                    z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
+                else:  # b_chunk_dim == row_dim
+                    # In this case, the b-dimension (b_chunk_dim) is partially
+                    # overwritten at the end of each iteration. We need to
+                    # restore the missing component from the z-cache.
+                    if not z_cache_rotated:
+                        z_chunk_slicer = empty_slicer(z_chunk_b)
+                        z_chunk_slicer[col_dim] = slice(0, half_n)
+                        z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
+                    else:
+                        z_cache_offset = i - half_n
+                        z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)
+
+                b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
+                del z_chunk_b
+
+                x_chunk = torch.matmul(a, b_chunk)
+                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
+                x_chunk = self.layer_norm_out(x_chunk)
+                x_chunk = self.linear_z(x_chunk)
+
+                # The g dimension (col_dim) is parallel to and ahead of the
+                # overwrites in z. We can extract the g chunk normally.
+                z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
+                g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
+                g_chunk.sigmoid_()
+                del z_chunk_g
+
+                x_chunk *= g_chunk
+
+                # Write the columns into z in-place
+                z_slicer = empty_slicer(z)
+                z_slicer[col_dim] = slice(i, i + offset)
+                if with_add:
+                    z[z_slicer] += x_chunk
+                else:
+                    z[z_slicer] = x_chunk
+        else:
+            b = compute_projection(z, mask, False, False)
+            x = torch.matmul(a, b)
+            x = self.layer_norm_out(x)
+            x = self.linear_z(x)
+            g = self.linear_g(z)
+            g.sigmoid_()
+            x *= g
+            if with_add:
+                z += x
+            else:
+                z = x
+
+        return z
+
+    def forward(
+        self,
+        z: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        inplace_safe: bool = False,
+        _add_with_inplace: bool = False,
+        _inplace_chunk_size: Optional[int] = 256,
+    ) -> torch.Tensor:
+        """
+        Args:
+            z:
+                [*, N_res, N_res, C_z] input tensor
+            mask:
+                [*, N_res, N_res] input mask
+        Returns:
+            [*, N_res, N_res, C_z] output tensor
+        """
+        if inplace_safe:
+            x = self._inference_forward(
+                z,
+                mask,
+                inplace_chunk_size=_inplace_chunk_size,
+                with_add=_add_with_inplace,
+            )
+            return x
+
+        if mask is None:
+            mask = z.new_ones(z.shape[:-1])
+
+        mask = mask.unsqueeze(-1)
+
+        z = self.layer_norm_in(z)
+        a = mask
+        a = a * self.sigmoid(self.linear_a_g(z))
+        a = a * self.linear_a_p(z)
+        b = mask
+        b = b * self.sigmoid(self.linear_b_g(z))
+        b = b * self.linear_b_p(z)
+
+        device_type = a.device.type if a.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
+                x = self._combine_projections(a.float(), b.float())
+        else:
+            x = self._combine_projections(a, b)
+
+        del a, b
+        x = self.layer_norm_out(x)
+        x = self.linear_z(x)
+        g = self.sigmoid(self.linear_g(z))
+        x = x * g
+
+        return x
+
+
+class EsmFoldPreTrainedModel(EsmPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    # Subclass `EsmPreTrainedModel` to deal with special init
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, EsmFoldLinear):
+            with torch.no_grad():
+                if module.init_fn is not None:
+                    module.init_fn(module.weight, module.bias)
+                elif module.init == "default":
+                    trunc_normal_init_(module.weight, scale=1.0)
+                elif module.init == "relu":
+                    trunc_normal_init_(module.weight, scale=2.0)
+                elif module.init == "glorot":
+                    nn.init.xavier_uniform_(module.weight, gain=1)
+                elif module.init == "gating":
+                    module.weight.fill_(0.0)
+                    if module.bias is not None:
+                        module.bias.fill_(1.0)
+                elif module.init == "normal":
+                    torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
+                elif module.init == "final":
+                    module.weight.fill_(0.0)
+        elif isinstance(module, EsmFoldInvariantPointAttention):
+            ipa_point_weights_init_(module.head_weights)
+        elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
+            torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
+            torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
+            torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
+            torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
+            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
+            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
+            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
+            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)
+
+            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
+            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
+            torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
+            torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
+            torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
+            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
+            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
+            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
+            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
+        else:
+            super()._init_weights(module)
+
+
+class EsmFoldSelfAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads, head_width, gated=False):
+        super().__init__()
+        assert embed_dim == num_heads * head_width
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_width = head_width
+
+        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
+        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+        self.gated = gated
+        if gated:
+            self.g_proj = nn.Linear(embed_dim, embed_dim)
+            torch.nn.init.zeros_(self.g_proj.weight)
+            torch.nn.init.ones_(self.g_proj.bias)
+
+        self.rescale_factor = self.head_width**-0.5
+
+        torch.nn.init.zeros_(self.o_proj.bias)
+
+    def forward(self, x, mask=None, bias=None, indices=None):
+        """
+        Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
+        use mask.
+
+        Inputs:
+            x: batch of input sequences (.. x L x C)
+            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k)
+            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)
+
+        Outputs:
+          sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
+        """
+
+        t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
+        t = t.permute(0, 2, 1, 3)
+        q, k, v = t.chunk(3, dim=-1)
+
+        q = self.rescale_factor * q
+        a = torch.einsum("...qc,...kc->...qk", q, k)
+
+        # Add external attention bias.
+        if bias is not None:
+            a = a + bias.permute(0, 3, 1, 2)
+
+        # Do not attend to padding tokens.
+        if mask is not None:
+            mask = mask[:, None, None]
+            a = a.masked_fill(mask == False, -np.inf)  # noqa: E712
+
+        a = nn.functional.softmax(a, dim=-1)
+
+        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
+        y = y.reshape(*y.shape[:2], -1)
+
+        if self.gated:
+            y = self.g_proj(x).sigmoid() * y
+        y = self.o_proj(y)
+
+        return y, a.permute(0, 3, 1, 2)
+
+
+class EsmFoldDropout(nn.Module):
+    """
+    Implementation of dropout with the ability to share the dropout mask along a particular dimension.
+    """
+
+    def __init__(self, r: float, batch_dim: Union[int, list[int]]):
+        super().__init__()
+
+        self.r = r
+        if isinstance(batch_dim, int):
+            batch_dim = [batch_dim]
+        self.batch_dim = batch_dim
+        self.dropout = nn.Dropout(self.r)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shape = list(x.shape)
+        if self.batch_dim is not None:
+            for bd in self.batch_dim:
+                shape[bd] = 1
+        return x * self.dropout(x.new_ones(shape))
+
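+# Example (illustrative): with batch_dim=1, a (B, I, J, C) input gets a dropout mask of
+# shape (B, 1, J, C), so the same pattern is reused along dimension 1; this is how
+# row_drop/col_drop below drop whole rows or columns of the pair representation at once.
+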
+
+class EsmFoldSequenceToPair(nn.Module):
+    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
+        super().__init__()
+
+        self.layernorm = nn.LayerNorm(sequence_state_dim)
+        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
+        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
+
+        torch.nn.init.zeros_(self.proj.bias)
+        torch.nn.init.zeros_(self.o_proj.bias)
+
+    def forward(self, sequence_state):
+        """
+        Inputs:
+          sequence_state: B x L x sequence_state_dim
+
+        Output:
+          pairwise_state: B x L x L x pairwise_state_dim
+
+        Intermediate state:
+          B x L x L x 2*inner_dim
+        """
+
+        assert len(sequence_state.shape) == 3
+
+        s = self.layernorm(sequence_state)
+        s = self.proj(s)
+        q, k = s.chunk(2, dim=-1)
+
+        prod = q[:, None, :, :] * k[:, :, None, :]
+        diff = q[:, None, :, :] - k[:, :, None, :]
+
+        x = torch.cat([prod, diff], dim=-1)
+        x = self.o_proj(x)
+
+        return x
+
+
+class EsmFoldPairToSequence(nn.Module):
+    def __init__(self, pairwise_state_dim, num_heads):
+        super().__init__()
+
+        self.layernorm = nn.LayerNorm(pairwise_state_dim)
+        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
+
+    def forward(self, pairwise_state):
+        """
+        Inputs:
+          pairwise_state: B x L x L x pairwise_state_dim
+
+        Output:
+          pairwise_bias: B x L x L x num_heads
+        """
+        assert len(pairwise_state.shape) == 4
+        z = self.layernorm(pairwise_state)
+        pairwise_bias = self.linear(z)
+        return pairwise_bias
+
+
+class EsmFoldResidueMLP(nn.Module):
+    def __init__(self, embed_dim, inner_dim, dropout=0):
+        super().__init__()
+
+        self.mlp = nn.Sequential(
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, inner_dim),
+            nn.ReLU(),
+            nn.Linear(inner_dim, embed_dim),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return x + self.mlp(x)
+
+
+class EsmFoldTriangularSelfAttentionBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        sequence_state_dim = config.sequence_state_dim
+        pairwise_state_dim = config.pairwise_state_dim
+        sequence_num_heads = sequence_state_dim // config.sequence_head_width
+        pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width
+
+        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
+
+        self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
+        self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)
+
+        self.seq_attention = EsmFoldSelfAttention(
+            sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
+        )
+        self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
+        self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)
+
+        self.tri_att_start = EsmFoldTriangleAttention(
+            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
+        )
+        self.tri_att_end = EsmFoldTriangleAttention(
+            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
+        )
+
+        self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
+        self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)
+
+        self.drop = nn.Dropout(config.dropout)
+        self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
+        self.col_drop = EsmFoldDropout(config.dropout * 2, 1)
+
+    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
+        """
+        Inputs:
+          sequence_state: B x L x sequence_state_dim
+          pairwise_state: B x L x L x pairwise_state_dim
+          mask: B x L boolean tensor of valid positions
+
+        Output:
+          sequence_state: B x L x sequence_state_dim
+          pairwise_state: B x L x L x pairwise_state_dim
+        """
+        if len(sequence_state.shape) != 3:
+            raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
+        if len(pairwise_state.shape) != 4:
+            raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
+        if mask is not None and len(mask.shape) != 2:
+            raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
+
+        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
+        pairwise_state_dim = pairwise_state.shape[3]
+
+        if sequence_state_dim != self.config.sequence_state_dim:
+            raise ValueError(
+                "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
+                f"{sequence_state_dim} != {self.config.sequence_state_dim}."
+            )
+        if pairwise_state_dim != self.config.pairwise_state_dim:
+            raise ValueError(
+                "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
+                f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
+            )
+        if batch_dim != pairwise_state.shape[0]:
+            raise ValueError(
+                f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
+                f"{pairwise_state.shape[0]}."
+            )
+        if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
+            raise ValueError(
+                f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
+                f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
+            )
+
+        # Update sequence state
+        bias = self.pair_to_sequence(pairwise_state)
+
+        # Self attention with bias + mlp.
+        y = self.layernorm_1(sequence_state)
+        y, _ = self.seq_attention(y, mask=mask, bias=bias)
+        sequence_state = sequence_state + self.drop(y)
+        sequence_state = self.mlp_seq(sequence_state)
+
+        # Update pairwise state
+        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
+
+        # Axial attention with triangular bias.
+        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
+        pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
+        pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
+        pairwise_state = pairwise_state + self.row_drop(
+            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
+        )
+        pairwise_state = pairwise_state + self.col_drop(
+            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
+        )
+
+        # MLP over pairs.
+        pairwise_state = self.mlp_pair(pairwise_state)
+
+        return sequence_state, pairwise_state
+
+
+class EsmCategoricalMixture:
+    def __init__(self, param, bins=50, start=0, end=1):
+        # All tensors are of shape ..., bins.
+        self.logits = param
+        bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
+        self.v_bins = (bins[:-1] + bins[1:]) / 2
+
+    def log_prob(self, true):
+        # Shapes are:
+        #     self.logits: ... x bins
+        #     true       : ...
+        true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
+        nll = self.logits.log_softmax(-1)
+        return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
+
+    def mean(self):
+        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
+
+
+def categorical_lddt(logits, bins=50):
+    # Logits are ..., 37, bins.
+    return EsmCategoricalMixture(logits, bins=bins).mean()
+
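+# Example (illustrative): with bins=50 over [0, 1], v_bins holds the 50 bin centres
+# 0.01, 0.03, ..., 0.99. mean() is the expectation of softmax(logits) over those centres,
+# and log_prob(true) returns the log-probability of the bin whose centre is closest to
+# `true`. categorical_lddt() applies mean() to (..., 37, bins) lddt logits, giving
+# per-atom confidence values in [0, 1].
+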
+
+def get_axial_mask(mask):
+    """
+    Helper to convert a B x L mask of valid positions to the axial mask used in row/column attentions.
+
+    Input:
+      mask: B x L tensor of booleans
+
+    Output:
+      mask: (B * L) x L tensor of booleans
+    """
+
+    if mask is None:
+        return None
+
+    if len(mask.shape) != 2:
+        raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
+    batch_dim, seq_dim = mask.shape
+    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
+    m = m.reshape(batch_dim * seq_dim, seq_dim)
+    return m
+
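+# Example (illustrative): for mask = [[1, 1, 0]] (B=1, L=3), the result has shape
+# (B * L, L) = (3, 3) with every row equal to [1, 1, 0], i.e. each query position sees
+# the same set of valid key positions.
+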
+
+class EsmFoldRelativePosition(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.bins = config.position_bins
+
+        # Note an additional offset is used so that the 0th position
+        # is reserved for masked pairs.
+        self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
+
+    def forward(self, residue_index, mask=None):
+        """
+        Input:
+          residue_index: B x L tensor of indices (dtype=torch.long)
+          mask: B x L tensor of booleans
+
+        Output:
+          pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
+        """
+        if residue_index.dtype != torch.long:
+            raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
+        if mask is not None and residue_index.shape != mask.shape:
+            raise ValueError(
+                f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
+            )
+
+        diff = residue_index[:, None, :] - residue_index[:, :, None]
+        diff = diff.clamp(-self.bins, self.bins)
+        diff = diff + self.bins + 1  # Add 1 to adjust for padding index.
+
+        if mask is not None:
+            mask = mask[:, None, :] * mask[:, :, None]
+            diff[mask == False] = 0  # noqa: E712
+
+        output = self.embedding(diff)
+        return output
+
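+# Example (illustrative): with config.position_bins = 2, residue indices 3 (query) and
+# 5 (key) give diff = 5 - 3 = 2, clamped to [-2, 2] and shifted to 2 + 2 + 1 = 5, i.e.
+# the embedding row for "two residues ahead"; masked pairs are mapped to index 0.
+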
+
+class EsmFoldAngleResnetBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
+        self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
+
+        self.relu = nn.ReLU()
+
+    def forward(self, a: torch.Tensor) -> torch.Tensor:
+        s_initial = a
+
+        a = self.relu(a)
+        a = self.linear_1(a)
+        a = self.relu(a)
+        a = self.linear_2(a)
+
+        return a + s_initial
+
+
+class EsmFoldAngleResnet(nn.Module):
+    """
+    Implements Algorithm 20, lines 11-14
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
+        self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
+
+        self.layers = nn.ModuleList()
+        for _ in range(config.num_resnet_blocks):
+            layer = EsmFoldAngleResnetBlock(config)
+            self.layers.append(layer)
+
+        self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
+
+        self.relu = nn.ReLU()
+
+    def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            s:
+                [*, C_hidden] single embedding
+            s_initial:
+                [*, C_hidden] single embedding as of the start of the StructureModule
+        Returns:
+            [*, no_angles, 2] predicted angles
+        """
+        # NOTE: The ReLUs applied to the inputs are absent from the supplement
+        # pseudocode but present in the source. For maximal compatibility with
+        # the pretrained weights, I'm going with the source.
+
+        # [*, C_hidden]
+        s_initial = self.relu(s_initial)
+        s_initial = self.linear_initial(s_initial)
+        s = self.relu(s)
+        s = self.linear_in(s)
+        s = s + s_initial
+
+        for layer in self.layers:
+            s = layer(s)
+
+        s = self.relu(s)
+
+        # [*, no_angles * 2]
+        s = self.linear_out(s)
+
+        # [*, no_angles, 2]
+        s = s.view(s.shape[:-1] + (-1, 2))
+
+        unnormalized_s = s
+        norm_denom = torch.sqrt(
+            torch.clamp(
+                torch.sum(s**2, dim=-1, keepdim=True),
+                min=self.config.epsilon,
+            )
+        )
+        s = s / norm_denom
+
+        return unnormalized_s, s
+
+
+class EsmFoldInvariantPointAttention(nn.Module):
+    """
+    Implements Algorithm 22.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        c_s = config.sequence_dim
+        c_z = config.pairwise_dim
+        self.hidden_dim = config.ipa_dim
+        self.num_heads = config.num_heads_ipa
+        self.num_qk_points = config.num_qk_points
+        self.num_v_points = config.num_v_points
+
+        # These linear layers differ from their specifications in the
+        # supplement. There, they lack bias and use Glorot initialization.
+        # Here as in the official source, they have bias and use the default
+        # Lecun initialization.
+        hc = config.ipa_dim * config.num_heads_ipa
+        self.linear_q = EsmFoldLinear(c_s, hc)
+        self.linear_kv = EsmFoldLinear(c_s, 2 * hc)
+
+        hpq = config.num_heads_ipa * config.num_qk_points * 3
+        self.linear_q_points = EsmFoldLinear(c_s, hpq)
+
+        hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
+        self.linear_kv_points = EsmFoldLinear(c_s, hpkv)
+
+        self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)
+
+        self.head_weights = nn.Parameter(torch.zeros(config.num_heads_ipa))
+
+        concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
+        self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")
+
+        self.softmax = nn.Softmax(dim=-1)
+        self.softplus = nn.Softplus()
+
+    def forward(
+        self,
+        s: torch.Tensor,
+        z: Optional[torch.Tensor],
+        r: Rigid,
+        mask: torch.Tensor,
+        _offload_inference: bool = False,
+        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+            z:
+                [*, N_res, N_res, C_z] pair representation
+            r:
+                [*, N_res] transformation object
+            mask:
+                [*, N_res] mask
+        Returns:
+            [*, N_res, C_s] single representation update
+        """
+        z = [z]
+
+        #######################################
+        # Generate scalar and point activations
+        #######################################
+        # [*, N_res, H * C_hidden]
+        q = self.linear_q(s)
+        kv = self.linear_kv(s)
+
+        # [*, N_res, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.num_heads, -1))
+
+        # [*, N_res, H, 2 * C_hidden]
+        kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
+
+        # [*, N_res, H, C_hidden]
+        k, v = torch.split(kv, self.hidden_dim, dim=-1)
+
+        # [*, N_res, H * P_q * 3]
+        q_pts = self.linear_q_points(s)
+
+        # This is kind of clunky, but it's how the original does it
+        # [*, N_res, H * P_q, 3]
+        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
+        q_pts = torch.stack(q_pts, dim=-1)
+        q_pts = r[..., None].apply(q_pts)
+
+        # [*, N_res, H, P_q, 3]
+        q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))
+
+        # [*, N_res, H * (P_q + P_v) * 3]
+        kv_pts = self.linear_kv_points(s)
+
+        # [*, N_res, H * (P_q + P_v), 3]
+        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
+        kv_pts = torch.stack(kv_pts, dim=-1)
+        kv_pts = r[..., None].apply(kv_pts)
+
+        # [*, N_res, H, (P_q + P_v), 3]
+        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))
+
+        # [*, N_res, H, P_q/P_v, 3]
+        k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)
+
+        ##########################
+        # Compute attention scores
+        ##########################
+        # [*, N_res, N_res, H]
+        b = self.linear_b(z[0])
+
+        if _offload_inference:
+            assert sys.getrefcount(z[0]) == 2
+            z[0] = z[0].cpu()
+
+        # [*, H, N_res, N_res]
+        device_type = q.device.type if q.device.type != "mps" else "cpu"
+        if is_fp16_enabled(device_type):
+            with torch.autocast(device_type=device_type, enabled=False):
+                a = torch.matmul(
+                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
+                )
+        else:
+            a = torch.matmul(
+                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
+            )
+
+        a *= math.sqrt(1.0 / (3 * self.hidden_dim))
+        a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
+
+        # [*, N_res, N_res, H, P_q, 3]
+        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
+        pt_att = pt_att**2
+
+        # [*, N_res, N_res, H, P_q]
+        pt_att = sum(torch.unbind(pt_att, dim=-1))
+        head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
+        head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
+        pt_att = pt_att * head_weights
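+        # The constants above appear to follow the IPA weighting from the AlphaFold
+        # supplement (Algorithm 22): roughly w_L = sqrt(1/3) balances the three attention
+        # terms and w_C = sqrt(2 / (9 * N_qk)) normalizes the point term, with
+        # self.head_weights acting as the learned per-head gamma (softplus'd above).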
+
+        # [*, N_res, N_res, H]
+        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
+        # [*, N_res, N_res]
+        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
+        square_mask = self.config.inf * (square_mask - 1)
+
+        # [*, H, N_res, N_res]
+        pt_att = permute_final_dims(pt_att, (2, 0, 1))
+
+        a = a + pt_att
+        a = a + square_mask.unsqueeze(-3)
+        a = self.softmax(a)
+
+        ################
+        # Compute output
+        ################
+        # [*, N_res, H, C_hidden]
+        o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
+
+        # [*, N_res, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, H, 3, N_res, P_v]
+        o_pt = torch.sum(
+            (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
+            dim=-2,
+        )
+
+        # [*, N_res, H, P_v, 3]
+        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
+        o_pt = r[..., None, None].invert_apply(o_pt)
+
+        # [*, N_res, H * P_v]
+        o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)
+
+        # [*, N_res, H * P_v, 3]
+        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
+
+        if _offload_inference:
+            z[0] = z[0].to(o_pt.device)
+
+        # [*, N_res, H, C_z]
+        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
+
+        # [*, N_res, H * C_z]
+        o_pair = flatten_final_dims(o_pair, 2)
+
+        # [*, N_res, C_s]
+        s = self.linear_out(
+            torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
+        )
+
+        return s
+
+
+class EsmFoldBackboneUpdate(nn.Module):
+    """
+    Implements part of Algorithm 23.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
+
+    def forward(self, s: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+        Returns:
+            [*, N_res, 6] update vector
+        """
+        # [*, N_res, 6]
+        update = self.linear(s)
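+        # Per Algorithm 23, these 6 numbers are typically read as the non-trivial
+        # (b, c, d) components of a quaternion update plus a 3-vector translation,
+        # which compose_q_update_vec later applies to the running rigid frames.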
+
+        return update
+
+
+class EsmFoldStructureModuleTransitionLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
+        self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
+        self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
+
+        self.relu = nn.ReLU()
+
+    def forward(self, s):
+        s_initial = s
+        s = self.linear_1(s)
+        s = self.relu(s)
+        s = self.linear_2(s)
+        s = self.relu(s)
+        s = self.linear_3(s)
+
+        s = s + s_initial
+
+        return s
+
+
+class EsmFoldStructureModuleTransition(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        self.layers = nn.ModuleList()
+        for _ in range(config.num_transition_layers):
+            l = EsmFoldStructureModuleTransitionLayer(config)
+            self.layers.append(l)
+
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.layer_norm = LayerNorm(config.sequence_dim)
+
+    def forward(self, s):
+        for l in self.layers:
+            s = l(s)
+
+        s = self.dropout(s)
+        s = self.layer_norm(s)
+
+        return s
+
+
+class EsmFoldStructureModule(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        # Buffers to be lazily initialized later
+        # self.default_frames
+        # self.group_idx
+        # self.atom_mask
+        # self.lit_positions
+
+        self.layer_norm_s = LayerNorm(config.sequence_dim)
+        self.layer_norm_z = LayerNorm(config.pairwise_dim)
+
+        self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)
+
+        self.ipa = EsmFoldInvariantPointAttention(config)
+
+        self.ipa_dropout = nn.Dropout(config.dropout_rate)
+        self.layer_norm_ipa = LayerNorm(config.sequence_dim)
+
+        self.transition = EsmFoldStructureModuleTransition(config)
+        self.bb_update = EsmFoldBackboneUpdate(config)
+        self.angle_resnet = EsmFoldAngleResnet(config)
+
+    def forward(
+        self,
+        evoformer_output_dict,
+        aatype,
+        mask=None,
+        _offload_inference=False,
+    ):
+        """
+        Args:
+            evoformer_output_dict:
+                Dictionary containing:
+                    "single":
+                        [*, N_res, C_s] single representation
+                    "pair":
+                        [*, N_res, N_res, C_z] pair representation
+            aatype:
+                [*, N_res] amino acid indices
+            mask:
+                Optional [*, N_res] sequence mask
+        Returns:
+            A dictionary of outputs
+        """
+        s = evoformer_output_dict["single"]
+
+        if mask is None:
+            # [*, N]
+            mask = s.new_ones(s.shape[:-1])
+
+        # [*, N, C_s]
+        s = self.layer_norm_s(s)
+
+        # [*, N, N, C_z]
+        z = self.layer_norm_z(evoformer_output_dict["pair"])
+
+        z_reference_list = None
+        if _offload_inference:
+            assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
+            z_reference_list = [z]
+            z = None
+
+        # [*, N, C_s]
+        s_initial = s
+        s = self.linear_in(s)
+
+        # [*, N]
+        rigids = Rigid.identity(
+            s.shape[:-1],
+            s.dtype,
+            s.device,
+            self.training,
+            fmt="quat",
+        )
+        outputs = []
+        for i in range(self.config.num_blocks):
+            # [*, N, C_s]
+            s = s + self.ipa(
+                s,
+                z,
+                rigids,
+                mask,
+                _offload_inference=_offload_inference,
+                _z_reference_list=z_reference_list,
+            )
+            s = self.ipa_dropout(s)
+            s = self.layer_norm_ipa(s)
+            s = self.transition(s)
+
+            # [*, N]
+            rigids = rigids.compose_q_update_vec(self.bb_update(s))
+
+            # To hew as closely as possible to AlphaFold, we convert our
+            # quaternion-based transformations to rotation-matrix ones
+            # here
+            backb_to_global = Rigid(
+                Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
+                rigids.get_trans(),
+            )
+
+            backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)
+
+            # [*, N, 7, 2]
+            unnormalized_angles, angles = self.angle_resnet(s, s_initial)
+
+            all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)
+
+            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)
+
+            scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)
+
+            preds = {
+                "frames": scaled_rigids.to_tensor_7(),
+                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+                "unnormalized_angles": unnormalized_angles,
+                "angles": angles,
+                "positions": pred_xyz,
+                "states": s,
+            }
+
+            outputs.append(preds)
+
+            rigids = rigids.stop_rot_gradient()
+
+        del z, z_reference_list
+
+        if _offload_inference:
+            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
+
+        outputs = dict_multimap(torch.stack, outputs)
+        outputs["single"] = s
+
+        return outputs
+
+    def _init_residue_constants(self, float_dtype, device):
+        if not hasattr(self, "default_frames"):
+            self.register_buffer(
+                "default_frames",
+                torch.tensor(
+                    residue_constants.restype_rigid_group_default_frame,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
+        if not hasattr(self, "group_idx"):
+            self.register_buffer(
+                "group_idx",
+                torch.tensor(
+                    residue_constants.restype_atom14_to_rigid_group,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
+        if not hasattr(self, "atom_mask"):
+            self.register_buffer(
+                "atom_mask",
+                torch.tensor(
+                    residue_constants.restype_atom14_mask,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
+        if not hasattr(self, "lit_positions"):
+            self.register_buffer(
+                "lit_positions",
+                torch.tensor(
+                    residue_constants.restype_atom14_rigid_group_positions,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
+
+    def torsion_angles_to_frames(self, r, alpha, f):
+        # Lazily initialize the residue constants on the correct device
+        self._init_residue_constants(alpha.dtype, alpha.device)
+        # Separated purely to make testing less annoying
+        return torsion_angles_to_frames(r, alpha, f, self.default_frames)
+
+    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # [*, N, 8]  # [*, N]
+        # Lazily initialize the residue constants on the correct device
+        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        return frames_and_literature_positions_to_atom14_pos(
+            r,
+            f,
+            self.default_frames,
+            self.group_idx,
+            self.atom_mask,
+            self.lit_positions,
+        )
+
+
+class EsmFoldingTrunk(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        c_s = config.sequence_state_dim
+        c_z = config.pairwise_state_dim
+
+        self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
+
+        self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
+
+        self.recycle_bins = 15
+        self.recycle_s_norm = nn.LayerNorm(c_s)
+        self.recycle_z_norm = nn.LayerNorm(c_z)
+        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
+        self.recycle_disto.weight[0].detach().zero_()
+
+        self.structure_module = EsmFoldStructureModule(config.structure_module)
+        self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
+        self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
+
+        self.chunk_size = config.chunk_size
+
+    def set_chunk_size(self, chunk_size):
+        # This parameter means the axial attention will be computed
+        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
+        # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
+        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
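+        # For example, a hypothetical caller could trade speed for memory on long sequences with:
+        #     model.trunk.set_chunk_size(64)  # process the axial attention in 64-length chunks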
+        self.chunk_size = chunk_size
+
+    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
+        """
+        Inputs:
+          seq_feats: B x L x C tensor of sequence features
+          pair_feats: B x L x L x C tensor of pair features
+          residx: B x L long tensor giving the position in the sequence
+          mask: B x L boolean tensor indicating valid residues
+
+        Output:
+          predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
+        """
+
+        device = seq_feats.device
+        s_s_0 = seq_feats
+        s_z_0 = pair_feats
+
+        if no_recycles is None:
+            no_recycles = self.config.max_recycles
+        else:
+            if no_recycles < 0:
+                raise ValueError("Number of recycles must not be negative.")
+            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.
+
+        def trunk_iter(s, z, residx, mask):
+            z = z + self.pairwise_positional_embedding(residx, mask=mask)
+
+            for block in self.blocks:
+                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
+            return s, z
+
+        s_s = s_s_0
+        s_z = s_z_0
+        recycle_s = torch.zeros_like(s_s)
+        recycle_z = torch.zeros_like(s_z)
+        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
+
+        for recycle_idx in range(no_recycles):
+            with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
+                # === Recycling ===
+                recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
+                recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
+                recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
+
+                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
+
+                # === Structure module ===
+                structure = self.structure_module(
+                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
+                    true_aa,
+                    mask.float(),
+                )
+
+                recycle_s = s_s
+                recycle_z = s_z
+                # The distogram is computed from the N, CA and C coordinates, using the same bin constants as AlphaFold.
+                recycle_bins = EsmFoldingTrunk.distogram(
+                    structure["positions"][-1][:, :, :3],
+                    3.375,
+                    21.375,
+                    self.recycle_bins,
+                )
+
+        structure["s_s"] = s_s
+        structure["s_z"] = s_z
+
+        return structure
+
+    @staticmethod
+    def distogram(coords, min_bin, max_bin, num_bins):
+        # Coords are [..., L, 3, 3]: the (N, CA, C) backbone atoms, each with 3 coordinates.
+        boundaries = torch.linspace(
+            min_bin,
+            max_bin,
+            num_bins - 1,
+            device=coords.device,
+        )
+        boundaries = boundaries**2
+        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
+        # Infer CB coordinates.
+        b = CA - N
+        c = C - CA
+        a = b.cross(c, dim=-1)
+        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
+        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
+        bins = torch.sum(dists > boundaries, dim=-1)  # [..., L, L]
+        return bins
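+        # Shape sketch: coords [..., L, 3, 3] (N, CA, C per residue) -> bins [..., L, L] of
+        # integer bin indices; e.g. the recycling loop above calls it as
+        #     EsmFoldingTrunk.distogram(structure["positions"][-1][:, :, :3], 3.375, 21.375, 15)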
+
+
+# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
+#      the outputs for downstream use.
+
+
+@auto_docstring(
+    custom_intro="""
+    ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
+    by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
+    the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
+    protein(s).
+    """
+)
+class EsmForProteinFolding(EsmPreTrainedModel):
+    _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
+    _supports_flash_attn = False
+    _supports_sdpa = False
+    _supports_attention_backend = False
+
+    _can_record_outputs = None
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.config = config
+
+        self.distogram_bins = 64
+
+        self.esm = EsmModel(config, add_pooling_layer=False)
+
+        self.esm.requires_grad_(False)
+        if self.config.esmfold_config.fp16_esm:
+            self.esm.half()
+
+        self.esm_feats = self.config.hidden_size
+        self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
+        self.esm_layers = self.config.num_hidden_layers
+        self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
+        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
+
+        trunk_config = self.config.esmfold_config.trunk
+        c_s = trunk_config.sequence_state_dim
+        c_z = trunk_config.pairwise_state_dim
+        self.esm_s_mlp = nn.Sequential(
+            LayerNorm(self.esm_feats),
+            nn.Linear(self.esm_feats, c_s),
+            nn.ReLU(),
+            nn.Linear(c_s, c_s),
+        )
+
+        # 0 is padding, N is unknown residues, N + 1 is mask.
+        self.n_tokens_embed = residue_constants.restype_num + 3
+        self.pad_idx = 0
+        self.unk_idx = self.n_tokens_embed - 2
+        self.mask_idx = self.n_tokens_embed - 1
+        self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
+        self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
+        self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
+        self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
+        if self.config.esmfold_config.embed_aa:
+            self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
+
+        self.trunk = EsmFoldingTrunk(trunk_config)
+
+        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
+        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
+        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
+        self.lddt_bins = 50
+        structure_module_config = trunk_config.structure_module
+        self.lddt_head = nn.Sequential(
+            nn.LayerNorm(structure_module_config.sequence_dim),
+            nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
+            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
+            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
+        )
+
+    @staticmethod
+    def _af2_to_esm_from_vocab_list(vocab_list: list[str]) -> torch.Tensor:
+        # Remember that the returned mapping is shifted from residue_constants by 1 (index 0 is padding).
+        esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
+        return torch.tensor(esm_reorder)
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        masking_pattern: Optional[torch.Tensor] = None,
+        num_recycles: Optional[int] = None,
+        output_hidden_states: Optional[bool] = False,
+    ) -> EsmForProteinFoldingOutput:
+        r"""
+        masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
+        num_recycles (`int`, *optional*, defaults to `None`):
+            Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
+            consists of passing the output of the folding trunk back in as input to the trunk. During training, the
+            number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
+            after each recycle. During inference, num_recycles should be set to the highest value that the model was
+            trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
+            used.
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, EsmForProteinFolding
+
+        >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
+        >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False)  # A tiny random peptide
+        >>> outputs = model(**inputs)
+        >>> folded_positions = outputs.positions
+        ```
+
+        """
+        cfg = self.config.esmfold_config
+
+        aa = input_ids  # B x L
+        B = aa.shape[0]
+        L = aa.shape[1]
+        device = input_ids.device
+        if attention_mask is None:
+            attention_mask = torch.ones_like(aa, device=device)
+        if position_ids is None:
+            position_ids = torch.arange(L, device=device).expand_as(input_ids)
+
+        # === ESM ===
+        esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)
+
+        if masking_pattern is not None:
+            masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
+        else:
+            masked_aa = aa
+            mlm_targets = None
+
+        # We get sequence and pair representations from whatever version of ESM /
+        # configuration we are using. The sequence representation esm_s is always
+        # present. The pair embedding esm_z may be present depending on the
+        # configuration of the model. If esm_z is not used by the model then it
+        # is returned as None here.
+        esm_s = self.compute_language_model_representations(esmaa)
+
+        # Convert esm_s and esm_z, if present, to the precision used by the trunk and
+        # the structure module. These tensors may be a lower precision if, for example,
+        # we're running the language model in fp16 precision.
+        esm_s = esm_s.to(self.esm_s_combine.dtype)
+
+        if cfg.esm_ablate_sequence:
+            esm_s = esm_s * 0
+
+        esm_s = esm_s.detach()
+
+        # === preprocessing ===
+        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
+        s_s_0 = self.esm_s_mlp(esm_s)
+
+        s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)
+
+        if self.config.esmfold_config.embed_aa:
+            s_s_0 += self.embedding(masked_aa)
+
+        structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
+        # Documenting what we expect:
+        structure = {
+            k: v
+            for k, v in structure.items()
+            if k
+            in [
+                "s_z",
+                "s_s",
+                "frames",
+                "sidechain_frames",
+                "unnormalized_angles",
+                "angles",
+                "positions",
+                "states",
+            ]
+        }
+
+        # Add BERT mask for the loss to use, if available.
+        if mlm_targets is not None:
+            structure["mlm_targets"] = mlm_targets
+
+        disto_logits = self.distogram_head(structure["s_z"])
+        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
+        structure["distogram_logits"] = disto_logits
+
+        lm_logits = self.lm_head(structure["s_s"])
+        structure["lm_logits"] = lm_logits
+
+        structure["aatype"] = aa
+        make_atom14_masks(structure)
+        # Of course, this doesn't respect the true mask because it doesn't know about it...
+        # We're not going to properly mask change of index tensors:
+        #    "residx_atom14_to_atom37",
+        #    "residx_atom37_to_atom14",
+        for k in [
+            "atom14_atom_exists",
+            "atom37_atom_exists",
+        ]:
+            structure[k] *= attention_mask.unsqueeze(-1)
+        structure["residue_index"] = position_ids
+
+        lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
+        structure["lddt_head"] = lddt_head
+        plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
+        structure["plddt"] = plddt
+
+        ptm_logits = self.ptm_head(structure["s_z"])
+        structure["ptm_logits"] = ptm_logits
+        structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
+        structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))
+
+        return EsmForProteinFoldingOutput(**structure)
+
+    def af2_idx_to_esm_idx(self, aa, mask):
+        # avoid indexing on different devices
+        if self.af2_to_esm.device != aa.device:
+            self.af2_to_esm = self.af2_to_esm.to(aa.device)
+        aa = (aa + 1).masked_fill(mask != 1, 0)
+        return self.af2_to_esm[aa]
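+        # i.e. an AlphaFold-style residue index i is looked up at slot i + 1 of the
+        # af2_to_esm table (slot 0 is reserved for padding), and any position with
+        # mask != 1 is mapped to the ESM padding token instead.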
+
+    def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
+        device = next(self.parameters()).device
+        B, L = esmaa.shape  # B = batch size, L = sequence length.
+
+        if self.config.esmfold_config.bypass_lm:
+            esm_s = torch.zeros(B, L, self.esm_s_combine.size(0), self.esm_feats, device=device)
+            return esm_s
+
+        bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
+        bos = esmaa.new_full((B, 1), bosi)
+        eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
+        esmaa = torch.cat([bos, esmaa, eos], dim=1)
+        # Use the first padding index as eos during inference.
+        esmaa[range(B), (esmaa != 1).sum(1)] = eosi
+
+        # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
+        # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
+        # esm_z is always None
+        esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
+        esm_s = torch.stack(esm_hidden_states, dim=2)
+
+        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C
+
+        return esm_s
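+        # Final shape is [B, L, num_hidden_layers + 1, hidden_size] (one slice per hidden
+        # state, embeddings included); the learned esm_s_combine weights, softmaxed in
+        # forward(), then mix these per-layer representations into a single one.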
+
+    def bert_mask(self, aa, esmaa, mask, pattern):
+        new_aa = aa.clone()
+        target = aa.clone()
+        new_esmaa = esmaa.clone()
+        new_aa[pattern == 1] = self.mask_idx
+        target[pattern != 1] = 0
+        new_esmaa[pattern == 1] = self.esm_dict_mask_idx
+        return new_aa, new_esmaa, target
+
+    @torch.no_grad()
+    def infer(
+        self,
+        seqs: Union[str, list[str]],
+        position_ids=None,
+    ):
+        if isinstance(seqs, str):
+            lst = [seqs]
+        else:
+            lst = seqs
+        # Returns the raw outputs of the model given an input sequence.
+        device = next(self.parameters()).device
+        aatype = collate_dense_tensors(
+            [
+                torch.from_numpy(
+                    residue_constants.sequence_to_onehot(
+                        sequence=seq,
+                        mapping=residue_constants.restype_order_with_x,
+                        map_unknown_to_x=True,
+                    )
+                )
+                .to(device)
+                .argmax(dim=1)
+                for seq in lst
+            ]
+        )  # B=1 x L
+        mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
+        position_ids = (
+            torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
+            if position_ids is None
+            else position_ids.to(device)
+        )
+        if position_ids.ndim == 1:
+            position_ids = position_ids.unsqueeze(0)
+        return self.forward(
+            aatype,
+            mask,
+            position_ids=position_ids,
+        )
+
+    @staticmethod
+    def output_to_pdb(output: dict) -> list[str]:
+        """Returns the pbd (file) string from the model given the model output."""
+        output = {k: v.to("cpu").numpy() for k, v in output.items()}
+        pdbs = []
+        final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
+        final_atom_mask = output["atom37_atom_exists"]
+        for i in range(output["aatype"].shape[0]):
+            aa = output["aatype"][i]
+            pred_pos = final_atom_positions[i]
+            mask = final_atom_mask[i]
+            resid = output["residue_index"][i] + 1
+            pred = OFProtein(
+                aatype=aa,
+                atom_positions=pred_pos,
+                atom_mask=mask,
+                residue_index=resid,
+                b_factors=output["plddt"][i],
+            )
+            pdbs.append(to_pdb(pred))
+        return pdbs
+
+    def infer_pdb(self, seqs, *args, **kwargs) -> str:
+        """Returns the pdb (file) string from the model given an input sequence."""
+        assert isinstance(seqs, str)
+        output = self.infer(seqs, *args, **kwargs)
+        return self.output_to_pdb(output)[0]
+
+    def infer_pdbs(self, seqs: list[str], *args, **kwargs) -> list[str]:
+        """Returns the pdb (file) string from the model given an input sequence."""
+        output = self.infer(seqs, *args, **kwargs)
+        return self.output_to_pdb(output)
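+    # A minimal usage sketch (checkpoint name as in the forward() example above; the
+    # output path is illustrative):
+    #     model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
+    #     pdb_string = model.infer_pdb("MLKNVQVQLV")
+    #     with open("prediction.pdb", "w") as f:
+    #         f.write(pdb_string)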
+
+
+__all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_tf_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_tf_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fd066868f0e86cd2e130ad170eb38c4791f4f88
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/modeling_tf_esm.py
@@ -0,0 +1,1574 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ESM model."""
+
+from __future__ import annotations
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_tf_outputs import (
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFBaseModelOutputWithPoolingAndCrossAttentions,
+    TFMaskedLMOutput,
+    TFSequenceClassifierOutput,
+    TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+    TFMaskedLanguageModelingLoss,
+    TFModelInputType,
+    TFPreTrainedModel,
+    TFSequenceClassificationLoss,
+    TFTokenClassificationLoss,
+    get_initializer,
+    keras,
+    shape_list,
+    unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, stable_softmax
+from ...utils import logging
+from .configuration_esm import EsmConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
+_CONFIG_FOR_DOC = "EsmConfig"
+
+
+def rotate_half(x):
+    x1, x2 = tf.split(x, 2, axis=-1)
+    return tf.concat((-x2, x1), axis=-1)
+
+
+def apply_rotary_pos_emb(x, cos, sin):
+    cos = cos[:, :, : tf.shape(x)[-2], :]
+    sin = sin[:, :, : tf.shape(x)[-2], :]
+
+    return (x * cos) + (rotate_half(x) * sin)
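+# i.e. for a feature vector x at position m this computes x * cos(m * theta) +
+# rotate_half(x) * sin(m * theta), the standard RoPE rotation of query/key channels.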
+
+
+def symmetrize(x):
+    "Make layer symmetric in final two dimensions, used for contact prediction."
+    return x + tf.linalg.matrix_transpose(x)  # Transposes last two dimensions only
+
+
+def average_product_correct(x):
+    "Perform average product correct, used for contact prediction."
+    a1 = tf.reduce_sum(x, -1, keepdims=True)
+    a2 = tf.reduce_sum(x, -2, keepdims=True)
+    a12 = tf.reduce_sum(x, (-1, -2), keepdims=True)
+
+    avg = a1 * a2
+    avg = avg / a12
+    normalized = x - avg
+    return normalized
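+# In matrix terms this is the usual average product correction for contact maps:
+#     APC(x)_ij = x_ij - (row_sum_i * col_sum_j) / total_sum
+# computed independently over the last two dimensions.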
+
+
+class TFRotaryEmbedding(keras.layers.Layer):
+    """
+    Rotary position embeddings based on those in
+    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
+    matrices which depend on their relative positions.
+    """
+
+    def __init__(self, dim: int, name=None):
+        super().__init__(name=name)
+        # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation
+        # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at
+        # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the
+        # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that
+        # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our
+        # models give different outputs from the original.
+        self.dim = dim
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.inv_freq = self.add_weight(
+            "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
+        )
+        self.inv_freq.assign(
+            1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
+        )
+
+    def _compute_cos_sin(self, x, seq_dimension=2):
+        seq_len = tf.shape(x)[seq_dimension]
+
+        t = tf.range(seq_len, dtype=self.inv_freq.dtype)
+        freqs = tf.einsum("i, j -> ij", t, self.inv_freq)  # Outer multiplication
+        emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]
+
+        return tf.cos(emb), tf.sin(emb)
+
+    def call(self, q: tf.Tensor, k: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
+        cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2)
+
+        return (
+            apply_rotary_pos_emb(q, cos_emb, sin_emb),
+            apply_rotary_pos_emb(k, cos_emb, sin_emb),
+        )
+
+
+class TFEsmContactPredictionHead(keras.layers.Layer):
+    """Performs symmetrization, apc, and computes a logistic regression on the output features"""
+
+    def __init__(
+        self,
+        in_features: int,
+        bias=True,
+        eos_idx: int = 2,
+        name=None,
+    ):
+        super().__init__(name=name)
+        self.eos_idx = eos_idx
+        self.in_features = in_features
+        self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression")
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "regression", None) is not None:
+            with tf.name_scope(self.regression.name):
+                self.regression.build((None, self.in_features))
+
+    def call(self, tokens, attentions):
+        # remove eos token attentions
+        eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)
+        eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)
+        attentions = attentions * eos_mask[:, None, None, :, :]
+        attentions = attentions[..., :-1, :-1]
+        # remove cls token attentions
+        attentions = attentions[..., 1:, 1:]
+        batch_size, layers, heads, seqlen, _ = shape_list(attentions)
+        attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))
+
+        # features: batch x channels x tokens x tokens (symmetric)
+        attentions = average_product_correct(symmetrize(attentions))
+        attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
+        return tf.squeeze(self.regression(attentions), 3)
+
+
+class TFEsmEmbeddings(keras.layers.Layer):
+    """
+    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+    """
+
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.word_embeddings = keras.layers.Embedding(
+            config.vocab_size,
+            config.hidden_size,
+            embeddings_initializer=get_initializer(config.initializer_range),
+            name="word_embeddings",
+        )
+        self.position_embeddings = keras.layers.Embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            embeddings_initializer=get_initializer(config.initializer_range),
+            name="position_embeddings",
+        )
+
+        if config.emb_layer_norm_before:
+            self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+        else:
+            self.layer_norm = None
+        # Matt: I think this line was copied incorrectly from BERT, disabling for now
+        # self.dropout = Dropout(config.hidden_dropout_prob)
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+        self.position_ids = tf.range(config.max_position_embeddings)[None, :]
+
+        self.padding_idx = config.pad_token_id
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
+        self.config = config
+
+    def call(
+        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+        if inputs_embeds is None:
+            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+        # embedding_scale factor here.
+        embeddings = inputs_embeds
+
+        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+        # masked tokens are treated as if they were selected for input dropout and zeroed out.
+        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
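+        # A rough numeric example of the rescaling below: training masks 15% * 80% = 12%
+        # of tokens, so a sample in which 20% of tokens are masked gets its embeddings
+        # scaled by (1 - 0.12) / (1 - 0.20) = 1.1.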
+        if self.token_dropout:
+            embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings)
+            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+            src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32)
+            masked_tokens = input_ids == self.mask_token_id
+            mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths
+            embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
+
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+
+        if self.layer_norm is not None:
+            embeddings = self.layer_norm(embeddings)
+        if attention_mask is not None:
+            embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype)
+        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
+        # embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: tf.Tensor
+
+        Returns: tf.Tensor
+        """
+        input_shape = shape_list(inputs_embeds)[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = tf.range(
+            start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64
+        )
+        return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "word_embeddings", None) is not None:
+            with tf.name_scope(self.word_embeddings.name):
+                self.word_embeddings.build(None)
+        if getattr(self, "position_embeddings", None) is not None:
+            with tf.name_scope(self.position_embeddings.name):
+                self.position_embeddings.build(None)
+        if getattr(self, "layer_norm", None) is not None:
+            with tf.name_scope(self.layer_norm.name):
+                self.layer_norm.build([None, None, self.config.hidden_size])
+
+
+class TFEsmSelfAttention(keras.layers.Layer):
+    def __init__(self, config, position_embedding_type=None, name=None):
+        super().__init__(name=name)
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
+        )
+        self.key = keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
+        )
+        self.value = keras.layers.Dense(
+            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
+        )
+
+        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        self.rotary_embeddings = None
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = keras.layers.Embedding(
+                2 * config.max_position_embeddings - 1,
+                self.attention_head_size,
+                embeddings_initializer=get_initializer(config.initializer_range),
+            )
+        elif self.position_embedding_type == "rotary":
+            self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings")
+
+        self.is_decoder = config.is_decoder
+        self.config = config
+
+    def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
+        new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
+        x = tf.reshape(x, new_x_shape)
+        return tf.transpose(x, perm=(0, 2, 1, 3))
+
+    def call(
+        self,
+        hidden_states: tf.Tensor,
+        attention_mask: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        encoder_hidden_states: tf.Tensor | None = None,
+        encoder_attention_mask: tf.Tensor | None = None,
+        past_key_value: tuple[tuple[tf.Tensor]] | None = None,
+        output_attentions: bool | None = False,
+        training: bool = False,
+    ) -> tuple[tf.Tensor]:
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+        # ESM code and fix rotary embeddings.
+        query_layer = query_layer * self.attention_head_size**-0.5
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        if self.position_embedding_type == "rotary":
+            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = shape_list(hidden_states)[1]
+            position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1)
+            position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = tf.cast(positional_embedding, query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = stable_softmax(attention_scores, axis=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs, training=training)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = attention_probs @ value_layer
+
+        context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3))
+        new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size]
+        context_layer = tf.reshape(context_layer, new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "query", None) is not None:
+            with tf.name_scope(self.query.name):
+                self.query.build([None, None, self.config.hidden_size])
+        if getattr(self, "key", None) is not None:
+            with tf.name_scope(self.key.name):
+                self.key.build([None, None, self.config.hidden_size])
+        if getattr(self, "value", None) is not None:
+            with tf.name_scope(self.value.name):
+                self.value.build([None, None, self.config.hidden_size])
+        if getattr(self, "rotary_embeddings", None) is not None:
+            with tf.name_scope(self.rotary_embeddings.name):
+                self.rotary_embeddings.build(None)
+
+
+class TFEsmSelfOutput(keras.layers.Layer):
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.dense = keras.layers.Dense(
+            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states, input_tensor, training=False):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states += input_tensor
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFEsmAttention(keras.layers.Layer):
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.self = TFEsmSelfAttention(config, name="self")
+        self.output_layer = TFEsmSelfOutput(config, name="output")
+        self.pruned_heads = set()
+        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+        self.config = config
+
+    def prune_heads(self, heads):
+        raise NotImplementedError
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        training=False,
+    ):
+        hidden_states_ln = self.LayerNorm(hidden_states)
+        self_outputs = self.self(
+            hidden_states_ln,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+            training,
+        )
+        attention_output = self.output_layer(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "self", None) is not None:
+            with tf.name_scope(self.self.name):
+                self.self.build(None)
+        if getattr(self, "output_layer", None) is not None:
+            with tf.name_scope(self.output_layer.name):
+                self.output_layer.build(None)
+        if getattr(self, "LayerNorm", None) is not None:
+            with tf.name_scope(self.LayerNorm.name):
+                self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFEsmIntermediate(keras.layers.Layer):
+    def __init__(self, config: EsmConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.intermediate_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            name="dense",
+        )
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = tf.nn.gelu(hidden_states)
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFEsmOutput(keras.layers.Layer):
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.dense = keras.layers.Dense(
+            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.config = config
+
+    def call(self, hidden_states, input_tensor, training=False):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states += input_tensor
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.intermediate_size])
+
+
+class TFEsmLayer(keras.layers.Layer):
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = TFEsmAttention(config, name="attention")
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = TFEsmAttention(config)
+        self.intermediate = TFEsmIntermediate(config, name="intermediate")
+        self.output_layer = TFEsmOutput(config, name="output")
+        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+        self.config = config
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        training=False,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+            training=training,
+        )
+        attention_output = self_attention_outputs[0]
+
+        # if decoder, the last output is tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise AttributeError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
+                    " with cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                cross_attn_past_key_value,
+                output_attentions,
+                training=training,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        layernorm_output = self.LayerNorm(attention_output)
+        intermediate_output = self.intermediate(hidden_states=layernorm_output)
+        layer_output = self.output_layer(
+            hidden_states=intermediate_output, input_tensor=attention_output, training=training
+        )
+        outputs = (layer_output,) + outputs  # add attentions if we output them
+
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "attention", None) is not None:
+            with tf.name_scope(self.attention.name):
+                self.attention.build(None)
+        if getattr(self, "intermediate", None) is not None:
+            with tf.name_scope(self.intermediate.name):
+                self.intermediate.build(None)
+        if getattr(self, "output_layer", None) is not None:
+            with tf.name_scope(self.output_layer.name):
+                self.output_layer.build(None)
+        if getattr(self, "LayerNorm", None) is not None:
+            with tf.name_scope(self.LayerNorm.name):
+                self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
+class TFEsmEncoder(keras.layers.Layer):
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.config = config
+        self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
+        self.emb_layer_norm_after = keras.layers.LayerNormalization(
+            epsilon=config.layer_norm_eps, name="emb_layer_norm_after"
+        )
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        training=False,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+                training,
+            )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if self.emb_layer_norm_after:
+            hidden_states = self.emb_layer_norm_after(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return TFBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "emb_layer_norm_after", None) is not None:
+            with tf.name_scope(self.emb_layer_norm_after.name):
+                self.emb_layer_norm_after.build([None, None, self.config.hidden_size])
+        if getattr(self, "layer", None) is not None:
+            for layer in self.layer:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm
+class TFEsmPooler(keras.layers.Layer):
+    def __init__(self, config: EsmConfig, **kwargs):
+        super().__init__(**kwargs)
+
+        self.dense = keras.layers.Dense(
+            units=config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        self.config = config
+
+    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(inputs=first_token_tensor)
+
+        return pooled_output
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFEsmPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = EsmConfig
+    base_model_prefix = "esm"
+
+
+ESM_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+    This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
+    regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior.
+
+    Parameters:
+        config ([`EsmConfig`]): Model configuration class with all the parameters of the
+            model. Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ESM_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`tf.Tensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
+    ESM_START_DOCSTRING,
+)
+class TFEsmMainLayer(keras.layers.Layer):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
+        super().__init__(name=name, **kwargs)
+
+        self.config = config
+        self.is_decoder = config.is_decoder
+
+        self.embeddings = TFEsmEmbeddings(config, name="embeddings")
+        self.encoder = TFEsmEncoder(config, name="encoder")
+        self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None
+
+        self.contact_head = TFEsmContactPredictionHead(
+            in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "pooler", None) is not None:
+            with tf.name_scope(self.pooler.name):
+                self.pooler.build(None)
+        if getattr(self, "contact_head", None) is not None:
+            with tf.name_scope(self.contact_head.name):
+                self.contact_head.build(None)
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value: tf.Variable):
+        self.embeddings.word_embeddings.weight = value
+        self.embeddings.vocab_size = shape_list(value)[0]
+
+    def _prune_heads(self, heads_to_prune):
+        raise NotImplementedError
+
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]:
+        if not self.config.is_decoder:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+
+        if past_key_values is None:
+            past_key_values_length = 0
+            past_key_values = [None] * len(self.encoder.layer)
+        else:
+            past_key_values_length = shape_list(past_key_values[0][0])[-2]
+
+        if attention_mask is None:
+            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+            training=training,
+        )
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        attention_mask_shape = shape_list(attention_mask)
+
+        mask_seq_length = seq_length + past_key_values_length
+        # Copied from `modeling_tf_t5.py`
+        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+        if self.is_decoder:
+            seq_ids = tf.range(mask_seq_length)
+            causal_mask = tf.less_equal(
+                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+                seq_ids[None, :, None],
+            )
+            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+            extended_attention_mask = causal_mask * attention_mask[:, None, :]
+            attention_mask_shape = shape_list(extended_attention_mask)
+            extended_attention_mask = tf.reshape(
+                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+            )
+            if past_key_values[0] is not None:
+                # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
+                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
+        else:
+            extended_attention_mask = tf.reshape(
+                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
+        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+        if self.is_decoder and encoder_attention_mask is not None:
+            # If a 2D or 3D attention mask is provided for the cross-attention,
+            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+            if num_dims_encoder_attention_mask == 3:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+            if num_dims_encoder_attention_mask == 2:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+            #                                         tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (
+                sequence_output,
+                pooled_output,
+            ) + encoder_outputs[1:]
+
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+    def predict_contacts(self, tokens, attention_mask):
+        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
+        attns = tf.stack(attns, axis=1)  # Matches the original model layout
+        # In the original model, attentions for padding tokens are completely zeroed out.
+        # This makes no difference most of the time because the other tokens won't attend to them,
+        # but it does for the contact prediction task, which takes attentions as input,
+        # so we have to mimic that here.
+        attention_mask = tf.cast(attention_mask, attns.dtype)
+        attns *= attention_mask[:, None, None, None]
+        attns *= attention_mask[:, None, None, :, None]
+        return self.contact_head(tokens, attns)
+
+
+@add_start_docstrings(
+    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
+    ESM_START_DOCSTRING,
+)
+class TFEsmModel(TFEsmPreTrainedModel):
+    def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        past_key_values: tuple[tuple[np.ndarray | tf.Tensor]] | None = None,
+        use_cache: bool | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool | None = False,
+    ) -> TFBaseModelOutputWithPoolingAndCrossAttentions | tuple[tf.Tensor]:
+        r"""
+        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        past_key_values (`tuple[tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation.
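+
+        Example (a minimal usage sketch; the checkpoint and protein sequence are illustrative):
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFEsmModel
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+        >>> model = TFEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
+
+        >>> inputs = tokenizer("MQIFVKTLTGKTITLE", return_tensors="tf")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
+        ```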
+        """
+        outputs = self.esm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        return outputs
+
+    def predict_contacts(self, tokens, attention_mask):
+        return self.esm.predict_contacts(tokens, attention_mask)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "esm", None) is not None:
+            with tf.name_scope(self.esm.name):
+                self.esm.build(None)
+
+
+@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
+class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+            )
+
+        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
+        self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
+
+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.decoder = new_embeddings
+
+    def get_lm_head(self):
+        return self.lm_head
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFMaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="",
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        kwargs (`dict[str, any]`, *optional*, defaults to `{}`):
+            Used to hide legacy arguments that have been deprecated.
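+
+        Example (a minimal usage sketch; the checkpoint and masked sequence are illustrative):
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFEsmForMaskedLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+        >>> model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
+
+        >>> inputs = tokenizer("MQIFVK<mask>LTGKTITLE", return_tensors="tf")
+        >>> logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)
+        ```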
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = outputs[0]
+        prediction_scores = self.lm_head(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return TFMaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def predict_contacts(self, tokens, attention_mask):
+        return self.esm.predict_contacts(tokens, attention_mask)
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "esm", None) is not None:
+            with tf.name_scope(self.esm.name):
+                self.esm.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build(None)
+
+
+class TFEsmLMHead(keras.layers.Layer):
+    """ESM Head for masked language modeling."""
+
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.dense = keras.layers.Dense(
+            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        )
+
+        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
+        if config.tie_word_embeddings:
+            self.decoder = None
+        else:
+            self.decoder = keras.layers.Dense(
+                config.vocab_size,
+                kernel_initializer=get_initializer(config.initializer_range),
+                name="decoder",
+                use_bias=False,
+            )
+        self.config = config
+
+    def build(self, input_shape=None):
+        # Separate bias to match the PT model and allow weight cross-loading to work
+        # Put it in the build so it gets the right name when adding it as a weight
+        if self.built:
+            return
+        self.built = True
+        self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+        if getattr(self, "layer_norm", None) is not None:
+            with tf.name_scope(self.layer_norm.name):
+                self.layer_norm.build([None, None, self.config.hidden_size])
+        if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings:
+            with tf.name_scope(self.decoder.name):
+                self.decoder.build([None, None, self.config.hidden_size])
+
+    def get_bias(self):
+        return {"bias": self.bias}
+
+    def call(self, features):
+        x = self.dense(features)
+        x = tf.nn.gelu(x)
+        x = self.layer_norm(x)
+
+        # project back to size of vocabulary with bias
+        if self.config.tie_word_embeddings:
+            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
+        else:
+            x = self.decoder(x) + self.bias
+        return x
+
+
+@add_start_docstrings(
+    """
+    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
+    """,
+    ESM_START_DOCSTRING,
+)
+class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss):
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
+        self.classifier = TFEsmClassificationHead(config, name="classifier")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
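+
+        Example (a minimal sketch; the base checkpoint is illustrative and the classification head is newly
+        initialized, so the logits are only meaningful after fine-tuning):
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFEsmForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+        >>> model = TFEsmForSequenceClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=2)
+
+        >>> inputs = tokenizer("MQIFVKTLTGKTITLE", return_tensors="tf")
+        >>> logits = model(**inputs).logits  # (batch_size, num_labels)
+        ```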
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "esm", None) is not None:
+            with tf.name_scope(self.esm.name):
+                self.esm.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build(None)
+
+
+@add_start_docstrings(
+    """
+    ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    Named-Entity-Recognition (NER) tasks.
+    """,
+    ESM_START_DOCSTRING,
+)
+class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss):
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFTokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+    ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output, training=training)
+        logits = self.classifier(sequence_output)
+
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFTokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "esm", None) is not None:
+            with tf.name_scope(self.esm.name):
+                self.esm.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, self.config.hidden_size])
+
+
+class TFEsmClassificationHead(keras.layers.Layer):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config, name=None):
+        super().__init__(name=name)
+        self.dense = keras.layers.Dense(
+            config.hidden_size,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="tanh",
+            name="dense",
+        )
+        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
+        self.out_proj = keras.layers.Dense(
+            config.num_labels,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="linear",
+            name="out_proj",
+        )
+        self.config = config
+
+    def call(self, features, training=False):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x, training=training)
+        x = self.dense(x)
+        x = self.dropout(x, training=training)
+        x = self.out_proj(x)
+        return x
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "dense", None) is not None:
+            with tf.name_scope(self.dense.name):
+                self.dense.build([None, None, self.config.hidden_size])
+        if getattr(self, "out_proj", None) is not None:
+            with tf.name_scope(self.out_proj.name):
+                self.out_proj.build([None, None, self.config.hidden_size])
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+    """
+    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+    are ignored. This is modified from fairseq's `utils.make_positions`.
+
+    Args:
+        input_ids: tf.Tensor
+        padding_idx: int
+        past_key_values_length: int
+
+    Returns: tf.Tensor
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = tf.cast(input_ids != padding_idx, tf.int64)
+    incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
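+    # Worked example (illustrative): with padding_idx=1 and input_ids [[5, 6, 1, 1]], the mask is
+    # [[1, 1, 0, 0]], the cumulative sum is [[1, 2, 2, 2]], multiplying by the mask zeroes the padded
+    # positions ([[1, 2, 0, 0]]), and adding padding_idx yields position ids [[2, 3, 1, 1]].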
+    return incremental_indices + padding_idx
+
+
+__all__ = [
+    "TFEsmForMaskedLM",
+    "TFEsmForSequenceClassification",
+    "TFEsmForTokenClassification",
+    "TFEsmModel",
+    "TFEsmPreTrainedModel",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/tokenization_esm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/tokenization_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d9705f7dbd33216a327eab04415ec57fe8e858d
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/esm/tokenization_esm.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for ESM."""
+
+import os
+from typing import Optional
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+def load_vocab_file(vocab_file):
+    with open(vocab_file, "r") as f:
+        lines = f.read().splitlines()
+        return [l.strip() for l in lines]
+
+
+class EsmTokenizer(PreTrainedTokenizer):
+    """
+    Constructs an ESM tokenizer.
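+
+    Example (an illustrative sketch; the checkpoint name is an assumption about a compatible vocabulary):
+
+    ```python
+    >>> from transformers import EsmTokenizer
+
+    >>> tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    >>> ids = tokenizer("MKTV")["input_ids"]  # amino-acid sequences are split into single-residue tokens
+    ```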
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="",
+        cls_token="",
+        pad_token="",
+        mask_token="",
+        eos_token="",
+        **kwargs,
+    ):
+        self.all_tokens = load_vocab_file(vocab_file)
+        self._id_to_token = dict(enumerate(self.all_tokens))
+        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
+        super().__init__(
+            unk_token=unk_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            eos_token=eos_token,
+            **kwargs,
+        )
+
+        # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
+        # none of them are special, but they all need special splitting.
+
+        self.unique_no_split_tokens = self.all_tokens
+        self._update_trie(self.unique_no_split_tokens)
+
+    def _convert_id_to_token(self, index: int) -> str:
+        return self._id_to_token.get(index, self.unk_token)
+
+    def _convert_token_to_id(self, token: str) -> int:
+        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+
+    def _tokenize(self, text, **kwargs):
+        return text.split()
+
+    def get_vocab(self):
+        base_vocab = self._token_to_id.copy()
+        base_vocab.update(self.added_tokens_encoder)
+        return base_vocab
+
+    def token_to_id(self, token: str) -> int:
+        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
+
+    def id_to_token(self, index: int) -> str:
+        return self._id_to_token.get(index, self.unk_token)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
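+        # Illustrative layout (per the logic below): a single sequence becomes
+        # `<cls> tokens_0 <eos>` and a pair becomes `<cls> tokens_0 <eos> tokens_1 <eos>`.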
+        cls = [self.cls_token_id]
+        sep = [self.eos_token_id]  # No sep token in ESM vocabulary
+        if token_ids_1 is None:
+            if self.eos_token_id is None:
+                return cls + token_ids_0
+            else:
+                return cls + token_ids_0 + sep
+        elif self.eos_token_id is None:
+            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
+        return cls + token_ids_0 + sep + token_ids_1 + sep  # Multiple inputs always have an EOS token
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids of the first sequence.
+            token_ids_1 (`list[int]`, *optional*):
+                List of ids of the second sequence.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+                )
+
+            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
+        mask = [1] + ([0] * len(token_ids_0)) + [1]
+        if token_ids_1 is not None:
+            mask += [0] * len(token_ids_1) + [1]
+        return mask
+
+    def save_vocabulary(self, save_directory, filename_prefix):
+        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
+        with open(vocab_file, "w") as f:
+            f.write("\n".join(self.all_tokens))
+        return (vocab_file,)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.all_tokens)
+
+
+__all__ = ["EsmTokenizer"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..202147c938465dd7dfcb7e79ecbeeb93ce632dbf
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_falcon_mamba import *
+    from .modeling_falcon_mamba import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..7630ebd6343ac968303fc0c31f2742bb352b4f8a
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/configuration_falcon_mamba.py
@@ -0,0 +1,170 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_falcon_mamba.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+from ...configuration_utils import PretrainedConfig
+
+
+class FalconMambaConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the FALCON_MAMBA
+    [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50280):
+            Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`FalconMambaModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the embeddings and hidden states.
+        state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the model.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            The id of the beginning of sentence token in the vocabulary.
+        eos_token_id (`int`, *optional*, defaults to 0):
+            The id of the end of sentence token in the vocabulary.
+        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+        use_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
+        use_conv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to use bias in the convolution layer of the mixer block.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.1):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
+        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+        time_step_scale (`float`, *optional*, defaults to 1.0):
+            Scale used to scale `dt_proj.bias`.
+        time_step_min (`float`, *optional*, defaults to 0.001):
+            Minimum `time_step` used to bound `dt_proj.bias`.
+        time_step_max (`float`, *optional*, defaults to 0.1):
+            Maximum `time_step` used to bound `dt_proj.bias`.
+        time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
+            Init scheme used for `dt_proj.weight`. Should be one of `["random", "constant"]`.
+        time_step_floor (`float`, *optional*, defaults to 0.0001):
+            Minimum clamping value of the `dt_proj.bias` layer initialization.
+        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+            Whether or not to rescale `out_proj` weights when initializing.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the cache should be used.
+        use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
+            This argument corresponds to `use_mambapy` in MambaConfig.
+            Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
+        mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
+            The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
+
+
+    Example:
+
+    ```python
+    >>> from transformers import FalconMambaConfig, FalconMambaModel
+
+    >>> # Initializing a FalconMamba configuration
+    >>> configuration = FalconMambaConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = FalconMambaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "falcon_mamba"
+
+    def __init__(
+        self,
+        vocab_size=50280,
+        hidden_size=768,
+        state_size=16,
+        num_hidden_layers=32,
+        layer_norm_epsilon=1e-5,
+        pad_token_id=0,
+        bos_token_id=0,
+        eos_token_id=0,
+        expand=2,
+        conv_kernel=4,
+        use_bias=False,
+        use_conv_bias=True,
+        hidden_act="silu",
+        initializer_range=0.1,
+        residual_in_fp32=True,
+        time_step_rank="auto",
+        time_step_scale=1.0,
+        time_step_min=0.001,
+        time_step_max=0.1,
+        time_step_init_scheme="random",
+        time_step_floor=1e-4,
+        rescale_prenorm_residual=False,
+        use_cache=True,
+        use_falcon_mambapy=False,
+        mixer_rms_eps=1e-6,
+        **kwargs,
+    ):
+        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.state_size = state_size
+        self.num_hidden_layers = num_hidden_layers
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.conv_kernel = conv_kernel
+        self.expand = expand
+        # This is needed since mamba overrides the intermediate_size attribute
+        self.intermediate_size = (
+            int(expand * self.hidden_size)
+            if kwargs.get("intermediate_size") is None
+            else kwargs.get("intermediate_size")
+        )
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        self.use_bias = use_bias
+        self.use_conv_bias = use_conv_bias
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
+        self.time_step_scale = time_step_scale
+        self.time_step_min = time_step_min
+        self.time_step_max = time_step_max
+        self.time_step_init_scheme = time_step_init_scheme
+        self.time_step_floor = time_step_floor
+        self.rescale_prenorm_residual = rescale_prenorm_residual
+        self.residual_in_fp32 = residual_in_fp32
+        self.use_cache = use_cache
+        self.use_falcon_mambapy = use_falcon_mambapy
+        self.mixer_rms_eps = mixer_rms_eps
+
+
+__all__ = ["FalconMambaConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cdf6da7bda376a34dd00545d96c44a99fc0e660
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -0,0 +1,937 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/falcon_mamba/modular_falcon_mamba.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_falcon_mamba.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.import_utils import (
+    is_causal_conv1d_available,
+    is_kernels_available,
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
+from .configuration_falcon_mamba import FalconMambaConfig
+
+
+if is_mambapy_available():
+    from mambapy.pscan import pscan
+else:
+    pscan = None
+
+if is_mamba_ssm_available():
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+
+    from ...kernels.falcon_mamba import mamba_inner_fn
+else:
+    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+
+
+logger = logging.get_logger(__name__)
+
+
+class FalconMambaCache:
+    """
+    Cache for the falcon_mamba model, which does not have an attention mechanism or key-value states.
+
+    Arguments:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
+            a smaller batch size is used.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+            The default `dtype` to use when initializing the layer.
+        device (`torch.device` or `str`, *optional*):
+            The device on which the cache should be initialized. Should be the same as the layer.
+
+    Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache
+
+        >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
+
+        >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+        >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device)  # sequence length
+        >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True)
+        >>> outputs.cache_params
+        ```
+    """
+
+    is_compileable = True
+
+    # TODO (joao): add layer_device_map arg and update code in `generate` accordingly
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        dtype: torch.dtype = torch.float16,
+        device: Union[torch.device, str, None] = None,
+    ):
+        self.max_batch_size = max_batch_size
+        self._dtype = dtype
+        self.intermediate_size = config.intermediate_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+
+        self.conv_states: list[torch.Tensor] = []
+        self.ssm_states: list[torch.Tensor] = []
+        device = torch.device(device) if device is not None else None
+        for _ in range(config.num_hidden_layers):
+            conv_state: torch.Tensor = torch.zeros(
+                self.max_batch_size,
+                self.intermediate_size,
+                self.conv_kernel_size,
+                device=device,
+                dtype=self._dtype,
+            )
+            ssm_state: torch.Tensor = torch.zeros(
+                self.max_batch_size,
+                self.intermediate_size,
+                self.ssm_state_size,
+                device=device,
+                dtype=self._dtype,
+            )
+
+            torch._dynamo.mark_static_address(conv_state)
+            torch._dynamo.mark_static_address(ssm_state)
+            self.conv_states.append(conv_state)
+            self.ssm_states.append(ssm_state)
+
+    def update_conv_state(
+        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+    ) -> torch.Tensor:
+        # This `if` block is only reached in multi-GPU setups and if `layer_device_map` is not passed. It is used
+        # when the cache is initialized in the forward pass (e.g. FalconMamba)
+        if self.conv_states[layer_idx].device != new_conv_state.device:
+            self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
+
+        conv_state = self.conv_states[layer_idx]
+        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+        conv_state = conv_state.roll(shifts=-1, dims=-1)
+        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
+        self.conv_states[layer_idx].zero_()
+        self.conv_states[layer_idx] += conv_state
+        return self.conv_states[layer_idx]
+
+    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
+        self.ssm_states[layer_idx].zero_()
+        self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
+        return self.ssm_states[layer_idx]
+
+    def reset(self):
+        for layer_idx in range(len(self.conv_states)):
+            # In-place ops prevent breaking the static address
+            self.conv_states[layer_idx].zero_()
+            self.ssm_states[layer_idx].zero_()
+
+
+def _lazy_load_causal_conv1d():
+    global _causal_conv1d_cache
+    if _causal_conv1d_cache is not None:
+        return _causal_conv1d_cache
+
+    if is_kernels_available():
+        from kernels import get_kernel
+
+        _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
+        _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
+    elif is_causal_conv1d_available():
+        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
+    else:
+        _causal_conv1d_cache = (None, None)
+    return _causal_conv1d_cache
+
+
+_causal_conv1d_cache = None
+
+
+def rms_forward(hidden_states, variance_epsilon=1e-6):
+    """
+    Calculates a simple RMSNorm with no learnable weights. `MambaRMSNorm` will
+    leverage this in order to multiply the final result with the RMSNorm weight.
+
+    Args:
+        hidden_states (`torch.Tensor`):
+            Hidden states to normalize
+        variance_epsilon (`float`):
+            The eps value to add in the square root scaling factor
+    """
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+    return hidden_states.to(input_dtype)
+
+
+class FalconMambaMixer(nn.Module):
+    """
+    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+    A, D are input independent (see FalconMamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+    ∆, B, C are input-dependent (this is a key difference between FalconMamba and the linear time invariant S4,
+    and is why FalconMamba is called **selective** state spaces)
+    """
+
+    def __init__(self, config: FalconMambaConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.intermediate_size
+        self.time_step_rank = int(config.time_step_rank)
+        self.layer_idx = layer_idx
+        self.use_conv_bias = config.use_conv_bias
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
+            bias=config.use_conv_bias,
+            kernel_size=config.conv_kernel,
+            groups=self.intermediate_size,
+            padding=config.conv_kernel - 1,
+        )
+
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+
+        self.use_falcon_mambapy = config.use_falcon_mambapy
+
+        # projection of the input hidden states
+        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
+        # selective projection used to make dt, B and C input dependent
+        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+        # time step projection (discretization)
+        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+        # S4D real initialization. These are not discretized!
+        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+        A = A.expand(self.intermediate_size, -1).contiguous()
+
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(self.intermediate_size))
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+        self.use_bias = config.use_bias
+
+        self.warn_slow_implementation()
+        # Triton expects RMS weights to be passed even if they are non-learnable, so we need to create these weights here
+        self.register_buffer(
+            "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
+        )
+        self.register_buffer(
+            "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
+        )
+        self.rms_eps = config.mixer_rms_eps
+
+    def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+        if not is_fast_path_available:
+            if self.use_falcon_mambapy:
+                if is_mambapy_available():
+                    logger.warning_once(
+                        "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                        " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                        " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
+                    )
+                else:
+                    raise ImportError(
+                        "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
+                    )
+            else:
+                logger.warning_once(
+                    "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                    " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                    " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                )
+
+    def cuda_kernels_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
+            contextualized_states = mamba_inner_fn(
+                projected_states,
+                self.conv1d.weight,
+                self.conv1d.bias if self.use_conv_bias else None,
+                self.x_proj.weight,
+                self.dt_proj.weight,
+                self.out_proj.weight,
+                self.out_proj.bias.float() if self.use_bias else None,
+                -torch.exp(self.A_log.float()),
+                None,  # input-dependent B
+                None,  # input-dependent C
+                self.D.float(),
+                delta_bias=self.dt_proj.bias.float(),
+                delta_softplus=True,
+                b_rms_weight=self.b_c_rms,
+                c_rms_weight=self.b_c_rms,
+                dt_rms_weight=self.dt_rms,
+                b_c_dt_rms_eps=self.rms_eps,
+            )
+
+        else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            hidden_states, gate = projected_states.chunk(2, dim=1)
+
+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+            # 2. Convolution sequence transformation
+            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+            if cache_params is not None and cache_position[0] > 0:
+                hidden_states = causal_conv1d_update(
+                    hidden_states.squeeze(-1),
+                    cache_params.conv_states[self.layer_idx],
+                    conv_weights,
+                    self.conv1d.bias,
+                    self.activation,
+                )
+                hidden_states = hidden_states.unsqueeze(-1)
+            else:
+                if cache_params is not None:
+                    conv_states = nn.functional.pad(
+                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
+                    )
+                    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
+                hidden_states = causal_conv1d_fn(
+                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
+                )
+
+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+            # 3. State Space Model sequence transformation
+            # 3.a. input varying initialization of time_step, B and C
+            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+            time_step, B, C = torch.split(
+                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+            )
+
+            B = rms_forward(B, variance_epsilon=self.rms_eps)
+            C = rms_forward(C, variance_epsilon=self.rms_eps)
+            time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+            # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
+            # at the price of a small overhead.
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
+            else:
+                discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
+
+            A = -torch.exp(self.A_log.float())
+            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
+            if cache_params is not None and cache_position[0] > 0:
+                scan_outputs = selective_state_update(
+                    cache_params.ssm_states[self.layer_idx],
+                    hidden_states[..., 0],
+                    discrete_time_step[..., 0],
+                    A,
+                    B[:, 0],
+                    C[:, 0],
+                    self.D,
+                    gate[..., 0],
+                    time_proj_bias,
+                    dt_softplus=True,
+                ).unsqueeze(-1)
+            else:
+                scan_outputs, ssm_state = selective_scan_fn(
+                    hidden_states,
+                    discrete_time_step,
+                    A,
+                    B.transpose(1, 2),
+                    C.transpose(1, 2),
+                    self.D.float(),
+                    gate,
+                    time_proj_bias,
+                    delta_softplus=True,
+                    return_last_state=True,
+                )
+                if ssm_state is not None and cache_params is not None:
+                    cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+            # 4. Final linear projection
+            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
+        return contextualized_states
+
+    # fmt: off
+    def slow_forward(self,
+        input_states,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        batch_size, seq_len, _ = input_states.shape
+        dtype = input_states.dtype
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
+        hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 2. Convolution sequence transformation
+        if cache_params is not None:
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+            ssm_state = ssm_state.to(hidden_states.device)
+            # use `cache_position.shape[0]` to check whether we are in prefill
+            # stage, it's equivalent to check `cache_position[0] == 0`, which
+            # breaks dynamo fullgraph constraints
+            if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
+                conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
+
+                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+                hidden_states = self.act(
+                    self.conv1d(hidden_states)[..., :seq_len]
+                )  # [batch, intermediate_size, seq_len]
+            else:
+                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+                conv_state = conv_state.to(self.conv1d.weight.device)
+                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+                if self.use_conv_bias:
+                    hidden_states += self.conv1d.bias
+                hidden_states = (
+                    self.act(hidden_states).to(dtype).unsqueeze(-1)
+                )  # [batch, intermediate_size, 1] : decoding
+        else:
+            ssm_state = torch.zeros(
+                (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+            )
+            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 3. State Space Model sequence transformation
+        # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+        time_step, B, C = torch.split(
+            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+        )
+
+        B = rms_forward(B, variance_epsilon=self.rms_eps)
+        C = rms_forward(C, variance_epsilon=self.rms_eps)
+        time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+        discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
+        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
+            1, 2
+        )  # [batch, intermediate_size, seq_len]
+
+        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+        A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
+        discrete_A = torch.exp(
+            A[None, :, None, :] * discrete_time_step[:, :, :, None]
+        )  # [batch, intermediate_size, seq_len, ssm_state_size]
+        discrete_B = (
+            discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
+        )  # [batch, intermediate_size, seq_len, ssm_state_size]
+        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+        if self.use_falcon_mambapy and self.training and cache_params is None:
+            hs = pscan(
+                discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
+            )  # [batch, seq_len, intermediate_size, ssm_state_size]
+            scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)  # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + hidden_states * self.D[None, :, None]
+            scan_output = scan_output * self.act(gate)
+        else:
+            scan_outputs = []
+            for i in range(seq_len):
+                ssm_state = (
+                    discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
+                )  # [batch, intermediate_size, ssm_state]
+                scan_output = torch.matmul(
+                    ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
+                )  # [batch, intermediate_size, 1]
+                scan_outputs.append(scan_output[:, :, 0])
+            scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + (hidden_states * self.D[None, :, None])
+            scan_output = scan_output * self.act(gate)
+
+            if cache_params is not None:
+                cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+        # 4. Final linear projection
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+        return contextualized_states
+    # fmt: on
+
+    def forward(
+        self,
+        hidden_states,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class FalconMambaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        FalconMambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        return self.weight.to(hidden_states.device) * rms_forward(
+            hidden_states, variance_epsilon=self.variance_epsilon
+        )
+
+    def extra_repr(self):
+        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
+
+
+class FalconMambaBlock(GradientCheckpointingLayer):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.residual_in_fp32 = config.residual_in_fp32
+        self.norm = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.mixer = FalconMambaMixer(config, layer_idx=layer_idx)
+
+    def forward(
+        self,
+        hidden_states,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
+        if self.residual_in_fp32:
+            residual = residual.to(torch.float32)
+
+        hidden_states = self.mixer(
+            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
+        )
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+@auto_docstring
+class FalconMambaPreTrainedModel(PreTrainedModel):
+    config: FalconMambaConfig
+    base_model_prefix = "backbone"
+    _no_split_modules = ["FalconMambaBlock", "FalconMambaMixer"]
+    supports_gradient_checkpointing = True
+    _is_stateful = True
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        std = self.config.initializer_range
+        if isinstance(module, FalconMambaMixer):
+            # S4D real initialization. These are not discretized!
+            # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+            A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
+            A = A.expand(module.intermediate_size, -1).contiguous()
+            module.A_log.copy_(torch.log(A))
+            module.D.data.fill_(1.0)
+
+            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
+            if self.config.time_step_init_scheme == "constant":
+                nn.init.constant_(module.dt_proj.weight, dt_init_std)
+            elif self.config.time_step_init_scheme == "random":
+                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
+
+            dt = torch.exp(
+                torch.rand(self.config.intermediate_size)
+                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+                + math.log(self.config.time_step_min)
+            ).clamp(min=self.config.time_step_floor)
+            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+            inv_dt = dt + torch.log(-torch.expm1(-dt))
+            module.dt_proj.bias.copy_(inv_dt)
+            module.dt_proj.bias._no_reinit = True
+
+            nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
+            if module.conv1d.bias is not None:
+                if not getattr(module.conv1d.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.conv1d.bias)
+            nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
+
+            if self.config.rescale_prenorm_residual:
+                # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+                #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+                #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+                #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+                #
+                # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                p = module.out_proj.weight
+                p /= math.sqrt(self.config.num_hidden_layers)
+
+        if isinstance(module, nn.Linear):
+            if not getattr(module.weight, "_no_reinit", False):
+                nn.init.normal_(module.weight, std=std)
+            if module.bias is not None:
+                if not getattr(module.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.bias)
+        elif isinstance(module, FalconMambaRMSNorm):
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, std=std)
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Class for the FALCON_MAMBA model outputs.
+    """
+)
+class FalconMambaOutput(ModelOutput):
+    r"""
+    cache_params (`FalconMambaCache`):
+        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+        avoid providing the old `input_ids`.
+
+        Includes both the State space model state matrices after the selective scan, and the Convolutional states
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    cache_params: Optional[FalconMambaCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Base class for causal language model (or autoregressive) outputs.
+    """
+)
+class FalconMambaCausalLMOutput(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Language modeling loss (for next-token prediction).
+    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+    cache_params (`FalconMambaCache`):
+        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+        avoid providing the old `input_ids`.
+
+        Includes both the State space model state matrices after the selective scan, and the Convolutional states
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    cache_params: Optional[FalconMambaCache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+
+
+@auto_docstring
+class FalconMambaModel(FalconMambaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList(
+            [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
+        )
+
+        self.gradient_checkpointing = False
+        self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, new_embeddings):
+        self.embeddings = new_embeddings
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        cache_params: Optional[FalconMambaCache] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ) -> Union[tuple, FalconMambaOutput]:
+        r"""
+        cache_params (`FalconMambaCache`, *optional*):
+            If passed along, the model uses the previous state in all the blocks (which will give the output for the
+            `input_ids` provided as if the model added `state_input_ids + input_ids` as context).
+        use_cache (`bool`, *optional*):
+            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            use_cache = False
+
+        if use_cache:
+            if cache_params is None:
+                cache_params = FalconMambaCache(
+                    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
+                )
+                cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
+            elif cache_position is None:
+                # Cases where we do a manual forward instead of using `model.generate`, which initializes
+                # `cache_position` and makes sure it is not None. Throw an error here instead of resorting
+                # to some hack to guess the current cache position.
+                raise ValueError(
+                    "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
+                    "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
+                    "be initialized for you automatically"
+                )
+        else:
+            cache_params = None
+
+        hidden_states = inputs_embeds
+        all_hidden_states = () if output_hidden_states else None
+        for mixer_block in self.layers:
+            hidden_states = mixer_block(
+                hidden_states,
+                cache_params=cache_params,
+                cache_position=cache_position,
+                attention_mask=attention_mask,
+            )
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+        hidden_states = self.norm_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
+
+        return FalconMambaOutput(
+            last_hidden_state=hidden_states,
+            cache_params=cache_params if use_cache else None,
+            hidden_states=all_hidden_states,
+        )
+
+
+@auto_docstring(
+    custom_intro="""
+    The FALCON_MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+    """
+)
+class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.backbone = FalconMambaModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.backbone.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings):
+        return self.backbone.set_input_embeddings(new_embeddings)
+
+    def _update_model_kwargs_for_generation(
+        self, outputs: ModelOutput, model_kwargs: dict[str, Any], num_new_tokens: int = 1, **kwargs
+    ) -> dict[str, Any]:
+        model_kwargs["cache_params"] = outputs.get("cache_params", None)
+        if (
+            model_kwargs.get("use_cache", True)
+            and "cache_position" in model_kwargs
+            and model_kwargs["cache_position"] is not None
+        ):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
+        return model_kwargs
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        inputs_embeds=None,
+        use_cache=None,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        # Overwritten -- uses `cache_params` as opposed to `past_key_values`
+        model_inputs = {"input_ids": input_ids.contiguous()}
+        if use_cache and cache_params is None:
+            # We initialize `cache_position` to the full size of `conv_states` at the prefill stage:
+            # padding is applied when the input is shorter and truncation when it is longer, so this is
+            # equivalent to always matching the length of `cache_params.conv_states`,
+            # which is `config.conv_kernel`.
+            cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
+            if inputs_embeds is not None:
+                model_inputs = {"inputs_embeds": inputs_embeds}
+                max_batch_size = inputs_embeds.size(0)
+            else:
+                max_batch_size = input_ids.size(0)
+            cache_params = FalconMambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
+
+        if use_cache and cache_position[0] > 0:
+            model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
+            attention_mask = None
+
+        if not use_cache and inputs_embeds is not None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+
+        model_inputs.update(
+            {
+                "cache_params": cache_params,
+                "use_cache": use_cache,
+                "cache_position": cache_position,
+                "attention_mask": attention_mask,
+            }
+        )
+
+        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        return model_inputs
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[FalconMambaCache] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,  # for now we need this for generation
+    ) -> Union[tuple, FalconMambaCausalLMOutput]:
+        r"""
+        cache_params (`FalconMambaCache`, *optional*):
+            If passed along, the model uses the previous state in all the blocks (which will give the output for the
+            `input_ids` provided as if the model added `state_input_ids + input_ids` as context).
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+        use_cache (`bool`, *optional*):
+            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        falcon_mamba_outputs = self.backbone(
+            input_ids,
+            cache_params=cache_params,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            attention_mask=attention_mask,
+        )
+        hidden_states = falcon_mamba_outputs[0]
+
+        logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + falcon_mamba_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return FalconMambaCausalLMOutput(
+            loss=loss,
+            logits=logits,
+            cache_params=falcon_mamba_outputs.cache_params,
+            hidden_states=falcon_mamba_outputs.hidden_states,
+        )
+
+
+__all__ = ["FalconMambaForCausalLM", "FalconMambaModel", "FalconMambaPreTrainedModel", "FalconMambaCache"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..6df2be3a2652cf47100b82aa69be3c2554ba4161
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -0,0 +1,582 @@
+# coding=utf-8
+# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch FALCONMAMBA model."""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...utils import auto_docstring, logging
+from ...utils.import_utils import (
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
+from ..mamba.configuration_mamba import MambaConfig
+from ..mamba.modeling_mamba import (
+    MambaBlock,
+    MambaCache,
+    MambaCausalLMOutput,
+    MambaForCausalLM,
+    MambaMixer,
+    MambaModel,
+    MambaOutput,
+    MambaPreTrainedModel,
+    MambaRMSNorm,
+    _lazy_load_causal_conv1d,
+)
+
+
+logger = logging.get_logger(__name__)
+
+if is_mambapy_available():
+    from mambapy.pscan import pscan
+else:
+    pscan = None
+
+if is_mamba_ssm_available():
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+
+    from ...kernels.falcon_mamba import mamba_inner_fn
+else:
+    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
+
+_causal_conv1d_cache = None
+
+
+class FalconMambaConfig(MambaConfig):
+    """
+    This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the FALCON_MAMBA
+    [tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50280):
+            Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`FalconMambaModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the embeddings and hidden states.
+        state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the model.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            The id of the beginning of sentence token in the vocabulary.
+        eos_token_id (`int`, *optional*, defaults to 0):
+            The id of the end of sentence token in the vocabulary.
+        expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+        conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+        use_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
+        use_conv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to use bias in the convolution layer of the mixer block.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.1):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
+        time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
+            Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
+        time_step_scale (`float`, *optional*, defaults to 1.0):
+            Scale used to scale `dt_proj.bias`.
+        time_step_min (`float`, *optional*, defaults to 0.001):
+            Minimum `time_step` used to bound `dt_proj.bias`.
+        time_step_max (`float`, *optional*, defaults to 0.1):
+            Maximum `time_step` used to bound `dt_proj.bias`.
+        time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
+            Init scheme used for `dt_proj.weight`. Should be one of `["random", "constant"]`.
+        time_step_floor (`float`, *optional*, defaults to 0.0001):
+            Minimum clamping value of the `dt_proj.bias` layer initialization.
+        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+            Whether or not to rescale `out_proj` weights when initializing.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the cache should be used.
+        use_falcon_mambapy (`bool`, *optional*, defaults to `False`):
+            This argument corresponds to `use_mambapy` in MambaConfig.
+            Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
+        mixer_rms_eps (`float`, *optional*, defaults to 1e-06):
+            The RMS norm epsilon value that is used in the Mixer RMS norm for B, C and dt states.
+
+
+    Example:
+
+    ```python
+    >>> from transformers import FalconMambaConfig, FalconMambaModel
+
+    >>> # Initializing a FalconMamba configuration
+    >>> configuration = FalconMambaConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = FalconMambaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    def __init__(
+        self,
+        vocab_size=50280,
+        hidden_size=768,
+        state_size=16,
+        num_hidden_layers=32,
+        layer_norm_epsilon=1e-5,
+        pad_token_id=0,
+        bos_token_id=0,
+        eos_token_id=0,
+        expand=2,
+        conv_kernel=4,
+        use_bias=False,
+        use_conv_bias=True,
+        hidden_act="silu",
+        initializer_range=0.1,
+        residual_in_fp32=True,
+        time_step_rank="auto",
+        time_step_scale=1.0,
+        time_step_min=0.001,
+        time_step_max=0.1,
+        time_step_init_scheme="random",
+        time_step_floor=1e-4,
+        rescale_prenorm_residual=False,
+        use_cache=True,
+        use_falcon_mambapy=False,
+        mixer_rms_eps=1e-6,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            hidden_size=hidden_size,
+            state_size=state_size,
+            num_hidden_layers=num_hidden_layers,
+            layer_norm_epsilon=layer_norm_epsilon,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            expand=expand,
+            conv_kernel=conv_kernel,
+            use_bias=use_bias,
+            use_conv_bias=use_conv_bias,
+            hidden_act=hidden_act,
+            initializer_range=initializer_range,
+            residual_in_fp32=residual_in_fp32,
+            time_step_rank=time_step_rank,
+            time_step_scale=time_step_scale,
+            time_step_min=time_step_min,
+            time_step_max=time_step_max,
+            time_step_init_scheme=time_step_init_scheme,
+            time_step_floor=time_step_floor,
+            rescale_prenorm_residual=rescale_prenorm_residual,
+            use_cache=use_cache,
+            use_falcon_mambapy=use_falcon_mambapy,
+            **kwargs,
+        )
+        self.mixer_rms_eps = mixer_rms_eps
+        # This is needed since mamba overrides the intermediate_size attribute
+        self.intermediate_size = (
+            int(expand * self.hidden_size)
+            if kwargs.get("intermediate_size") is None
+            else kwargs.get("intermediate_size")
+        )
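+        # Illustrative note: with the documented defaults (expand=2, hidden_size=768) this
+        # resolves to 1536 unless an explicit `intermediate_size` kwarg is provided.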
+
+
+class FalconMambaCache(MambaCache):
+    """
+    Cache for the FalconMamba model, which has no attention mechanism and therefore no key/value states.
+
+    Arguments:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
+            a smaller batch size is used.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+            The default `dtype` to use when initializing the layer.
+        device (`torch.device` or `str`, *optional*):
+            The device on which the cache should be initialized. Should be the same as the layer.
+
+    Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoTokenizer, FalconMambaForCausalLM, FalconMambaCache
+
+        >>> model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
+
+        >>> inputs = tokenizer(text="My name is FalconMamba", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_params = FalconMambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+        >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device)  # sequence length
+        >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True)
+        >>> outputs.cache_params
+        ```
+    """
+
+    pass
+
+
+def rms_forward(hidden_states, variance_epsilon=1e-6):
+    """
+    Computes a simple RMSNorm with no learnable weights. `MambaRMSNorm` leverages this and
+    multiplies the final result by the RMSNorm weight.
+
+    Args:
+        hidden_states (`torch.Tensor`):
+            Hidden states to normalize
+        variance_epsilon (`float`):
+            The epsilon value added inside the square-root scaling factor.
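+
+    Example (illustrative sketch, assuming a float32 input tensor):
+
+    ```python
+    >>> import torch
+    >>> x = torch.randn(2, 4, 8)
+    >>> y = rms_forward(x, variance_epsilon=1e-6)
+    >>> torch.allclose(y, x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
+    True
+    ```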
+    """
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+    return hidden_states.to(input_dtype)
+
+
+class FalconMambaMixer(MambaMixer):
+    def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
+        if not is_fast_path_available:
+            if self.use_falcon_mambapy:
+                if is_mambapy_available():
+                    logger.warning_once(
+                        "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                        " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                        " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
+                    )
+                else:
+                    raise ImportError(
+                        "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
+                    )
+            else:
+                logger.warning_once(
+                    "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                    " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                    " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                )
+
+    def __init__(self, config: FalconMambaConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        # The Triton kernels expect RMS weights to be passed even when they are non-learnable, so we create them here
+        self.register_buffer(
+            "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
+        )
+        self.register_buffer(
+            "dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
+        )
+        self.rms_eps = config.mixer_rms_eps
+
+    def cuda_kernels_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
+            contextualized_states = mamba_inner_fn(
+                projected_states,
+                self.conv1d.weight,
+                self.conv1d.bias if self.use_conv_bias else None,
+                self.x_proj.weight,
+                self.dt_proj.weight,
+                self.out_proj.weight,
+                self.out_proj.bias.float() if self.use_bias else None,
+                -torch.exp(self.A_log.float()),
+                None,  # input-dependent B
+                None,  # input-dependent C
+                self.D.float(),
+                delta_bias=self.dt_proj.bias.float(),
+                delta_softplus=True,
+                b_rms_weight=self.b_c_rms,
+                c_rms_weight=self.b_c_rms,
+                dt_rms_weight=self.dt_rms,
+                b_c_dt_rms_eps=self.rms_eps,
+            )
+
+        else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+            hidden_states, gate = projected_states.chunk(2, dim=1)
+
+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+            # 2. Convolution sequence transformation
+            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+            if cache_params is not None and cache_position[0] > 0:
+                hidden_states = causal_conv1d_update(
+                    hidden_states.squeeze(-1),
+                    cache_params.conv_states[self.layer_idx],
+                    conv_weights,
+                    self.conv1d.bias,
+                    self.activation,
+                )
+                hidden_states = hidden_states.unsqueeze(-1)
+            else:
+                if cache_params is not None:
+                    conv_states = nn.functional.pad(
+                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
+                    )
+                    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
+                hidden_states = causal_conv1d_fn(
+                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
+                )
+
+            if attention_mask is not None:
+                hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+            # 3. State Space Model sequence transformation
+            # 3.a. input varying initialization of time_step, B and C
+            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+            time_step, B, C = torch.split(
+                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+            )
+
+            B = rms_forward(B, variance_epsilon=self.rms_eps)
+            C = rms_forward(C, variance_epsilon=self.rms_eps)
+            time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+            # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
+            # at the price of a small overhead.
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
+            else:
+                discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
+
+            A = -torch.exp(self.A_log.float())
+            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
+            if cache_params is not None and cache_position[0] > 0:
+                scan_outputs = selective_state_update(
+                    cache_params.ssm_states[self.layer_idx],
+                    hidden_states[..., 0],
+                    discrete_time_step[..., 0],
+                    A,
+                    B[:, 0],
+                    C[:, 0],
+                    self.D,
+                    gate[..., 0],
+                    time_proj_bias,
+                    dt_softplus=True,
+                ).unsqueeze(-1)
+            else:
+                scan_outputs, ssm_state = selective_scan_fn(
+                    hidden_states,
+                    discrete_time_step,
+                    A,
+                    B.transpose(1, 2),
+                    C.transpose(1, 2),
+                    self.D.float(),
+                    gate,
+                    time_proj_bias,
+                    delta_softplus=True,
+                    return_last_state=True,
+                )
+                if ssm_state is not None and cache_params is not None:
+                    cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+            # 4. Final linear projection
+            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
+        return contextualized_states
+
+    def slow_forward(
+        self,
+        input_states,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        batch_size, seq_len, _ = input_states.shape
+        dtype = input_states.dtype
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
+        hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 2. Convolution sequence transformation
+        if cache_params is not None:
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+            ssm_state = ssm_state.to(hidden_states.device)
+            # use `cache_position.shape[0]` to check whether we are in prefill
+            # stage, it's equivalent to check `cache_position[0] == 0`, which
+            # breaks dynamo fullgraph constraints
+            if cache_position is not None and cache_position.shape[0] == self.conv_kernel_size:
+                conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
+
+                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
+                hidden_states = self.act(
+                    self.conv1d(hidden_states)[..., :seq_len]
+                )  # [batch, intermediate_size, seq_len]
+            else:
+                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
+                conv_state = conv_state.to(self.conv1d.weight.device)
+                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+                if self.use_conv_bias:
+                    hidden_states += self.conv1d.bias
+                hidden_states = (
+                    self.act(hidden_states).to(dtype).unsqueeze(-1)
+                )  # [batch, intermediate_size, 1] : decoding
+        else:
+            ssm_state = torch.zeros(
+                (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+            )
+            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
+
+        if attention_mask is not None:
+            hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+        # 3. State Space Model sequence transformation
+        # 3.a. Selection:  [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+        time_step, B, C = torch.split(
+            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+        )
+
+        B = rms_forward(B, variance_epsilon=self.rms_eps)
+        C = rms_forward(C, variance_epsilon=self.rms_eps)
+        time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+
+        discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
+        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(
+            1, 2
+        )  # [batch, intermediate_size, seq_len]
+
+        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+        A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
+        discrete_A = torch.exp(
+            A[None, :, None, :] * discrete_time_step[:, :, :, None]
+        )  # [batch, intermediate_size, seq_len, ssm_state_size]
+        discrete_B = (
+            discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
+        )  # [batch, intermediate_size, seq_len, ssm_state_size]
+        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+        if self.use_falcon_mambapy and self.training and cache_params is None:
+            hs = pscan(
+                discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)
+            )  # [batch, seq_len, intermediate_size, ssm_state_size]
+            scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)  # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + hidden_states * self.D[None, :, None]
+            scan_output = scan_output * self.act(gate)
+        else:
+            scan_outputs = []
+            for i in range(seq_len):
+                ssm_state = (
+                    discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
+                )  # [batch, intermediate_size, ssm_state]
+                scan_output = torch.matmul(
+                    ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)
+                )  # [batch, intermediate_size, 1]
+                scan_outputs.append(scan_output[:, :, 0])
+            scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + (hidden_states * self.D[None, :, None])
+            scan_output = scan_output * self.act(gate)
+
+            if cache_params is not None:
+                cache_params.update_ssm_state(self.layer_idx, ssm_state)
+
+        # 4. Final linear projection
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+        return contextualized_states
+
+    def forward(
+        self,
+        hidden_states,
+        cache_params: Optional[FalconMambaCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
+        is_fast_path_available = all(
+            (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+        )
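+        # Dispatch note: the fused CUDA kernels are used only when every kernel import above
+        # resolved, the projection weights live on a CUDA device, and we are not tracing under
+        # torch.compile; otherwise the slower sequential path is taken.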
+        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class FalconMambaRMSNorm(MambaRMSNorm):
+    def forward(self, hidden_states):
+        return self.weight.to(hidden_states.device) * rms_forward(
+            hidden_states, variance_epsilon=self.variance_epsilon
+        )
+
+
+class FalconMambaBlock(MambaBlock):
+    pass
+
+
+@auto_docstring
+class FalconMambaPreTrainedModel(MambaPreTrainedModel):
+    pass
+
+
+class FalconMambaOutput(MambaOutput):
+    pass
+
+
+class FalconMambaCausalLMOutput(MambaCausalLMOutput):
+    pass
+
+
+class FalconMambaModel(MambaModel, FalconMambaPreTrainedModel):
+    def __init__(self, config):
+        FalconMambaPreTrainedModel.__init__(self, config)
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList(
+            [FalconMambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
+        )
+
+        self.gradient_checkpointing = False
+        self.norm_f = FalconMambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def load_hook(self, state_dict, prefix, *args):
+        raise AttributeError("Not needed for FalconMamba")
+
+
+class FalconMambaForCausalLM(MambaForCausalLM):
+    pass
+
+
+__all__ = [
+    "FalconMambaForCausalLM",
+    "FalconMambaModel",
+    "FalconMambaPreTrainedModel",
+    "FalconMambaCache",
+    "FalconMambaConfig",
+]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d1ec7236310774ed6b1379683c144d7f93ecce
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_fastspeech2_conformer import *
+    from .modeling_fastspeech2_conformer import *
+    from .tokenization_fastspeech2_conformer import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d65a261c64fbabe493aa37677aaacb6f226a3b
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py
@@ -0,0 +1,480 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""FastSpeech2Conformer model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class FastSpeech2ConformerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`FastSpeech2ConformerModel`]. It is used to
+    instantiate a FastSpeech2Conformer model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    FastSpeech2Conformer [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer)
+    architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 384):
+            The dimensionality of the hidden layers.
+        vocab_size (`int`, *optional*, defaults to 78):
+            The size of the vocabulary.
+        num_mel_bins (`int`, *optional*, defaults to 80):
+            The number of mel filters used in the filter bank.
+        encoder_num_attention_heads (`int`, *optional*, defaults to 2):
+            The number of attention heads in the encoder.
+        encoder_layers (`int`, *optional*, defaults to 4):
+            The number of layers in the encoder.
+        encoder_linear_units (`int`, *optional*, defaults to 1536):
+            The number of units in the linear layer of the encoder.
+        decoder_layers (`int`, *optional*, defaults to 4):
+            The number of layers in the decoder.
+        decoder_num_attention_heads (`int`, *optional*, defaults to 2):
+            The number of attention heads in the decoder.
+        decoder_linear_units (`int`, *optional*, defaults to 1536):
+            The number of units in the linear layer of the decoder.
+        speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
+            The number of layers in the post-net of the speech decoder.
+        speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
+            The number of units in the post-net layers of the speech decoder.
+        speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
+            The kernel size in the post-net of the speech decoder.
+        positionwise_conv_kernel_size (`int`, *optional*, defaults to 3):
+            The size of the convolution kernel used in the position-wise layer.
+        encoder_normalize_before (`bool`, *optional*, defaults to `False`):
+            Specifies whether to normalize before encoder layers.
+        decoder_normalize_before (`bool`, *optional*, defaults to `False`):
+            Specifies whether to normalize before decoder layers.
+        encoder_concat_after (`bool`, *optional*, defaults to `False`):
+            Specifies whether to concatenate after encoder layers.
+        decoder_concat_after (`bool`, *optional*, defaults to `False`):
+            Specifies whether to concatenate after decoder layers.
+        reduction_factor (`int`, *optional*, defaults to 1):
+            The factor by which the speech frame rate is reduced.
+        speaking_speed (`float`, *optional*, defaults to 1.0):
+            The speed of the speech produced.
+        use_macaron_style_in_conformer (`bool`, *optional*, defaults to `True`):
+            Specifies whether to use macaron style in the conformer.
+        use_cnn_in_conformer (`bool`, *optional*, defaults to `True`):
+            Specifies whether to use convolutional neural networks in the conformer.
+        encoder_kernel_size (`int`, *optional*, defaults to 7):
+            The kernel size used in the encoder.
+        decoder_kernel_size (`int`, *optional*, defaults to 31):
+            The kernel size used in the decoder.
+        duration_predictor_layers (`int`, *optional*, defaults to 2):
+            The number of layers in the duration predictor.
+        duration_predictor_channels (`int`, *optional*, defaults to 256):
+            The number of channels in the duration predictor.
+        duration_predictor_kernel_size (`int`, *optional*, defaults to 3):
+            The kernel size used in the duration predictor.
+        energy_predictor_layers (`int`, *optional*, defaults to 2):
+            The number of layers in the energy predictor.
+        energy_predictor_channels (`int`, *optional*, defaults to 256):
+            The number of channels in the energy predictor.
+        energy_predictor_kernel_size (`int`, *optional*, defaults to 3):
+            The kernel size used in the energy predictor.
+        energy_predictor_dropout (`float`, *optional*, defaults to 0.5):
+            The dropout rate in the energy predictor.
+        energy_embed_kernel_size (`int`, *optional*, defaults to 1):
+            The kernel size used in the energy embed layer.
+        energy_embed_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout rate in the energy embed layer.
+        stop_gradient_from_energy_predictor (`bool`, *optional*, defaults to `False`):
+            Specifies whether to stop gradients from the energy predictor.
+        pitch_predictor_layers (`int`, *optional*, defaults to 5):
+            The number of layers in the pitch predictor.
+        pitch_predictor_channels (`int`, *optional*, defaults to 256):
+            The number of channels in the pitch predictor.
+        pitch_predictor_kernel_size (`int`, *optional*, defaults to 5):
+            The kernel size used in the pitch predictor.
+        pitch_predictor_dropout (`float`, *optional*, defaults to 0.5):
+            The dropout rate in the pitch predictor.
+        pitch_embed_kernel_size (`int`, *optional*, defaults to 1):
+            The kernel size used in the pitch embed layer.
+        pitch_embed_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout rate in the pitch embed layer.
+        stop_gradient_from_pitch_predictor (`bool`, *optional*, defaults to `True`):
+            Specifies whether to stop gradients from the pitch predictor.
+        encoder_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The dropout rate in the encoder.
+        encoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The positional dropout rate in the encoder.
+        encoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The attention dropout rate in the encoder.
+        decoder_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The dropout rate in the decoder.
+        decoder_positional_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The positional dropout rate in the decoder.
+        decoder_attention_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The attention dropout rate in the decoder.
+        duration_predictor_dropout_rate (`float`, *optional*, defaults to 0.2):
+            The dropout rate in the duration predictor.
+        speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
+            The dropout rate in the speech decoder postnet.
+        max_source_positions (`int`, *optional*, defaults to 5000):
+            If `"relative"` position embeddings are used, defines the maximum source input positions.
+        use_masking (`bool`, *optional*, defaults to `True`):
+            Specifies whether to use masking in the model.
+        use_weighted_masking (`bool`, *optional*, defaults to `False`):
+            Specifies whether to use weighted masking in the model.
+        num_speakers (`int`, *optional*):
+            Number of speakers. If set to > 1, speaker ids are expected as an input and a speaker id embedding
+            layer is used.
+        num_languages (`int`, *optional*):
+            Number of languages. If set to > 1, language ids are expected as an input and a language id embedding
+            layer is used.
+        speaker_embed_dim (`int`, *optional*):
+            Speaker embedding dimension. If set to > 0, a `speaker_embedding` is expected as an input.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Specifies whether the model is an encoder-decoder.
+
+    Example:
+
+    ```python
+    >>> from transformers import FastSpeech2ConformerModel, FastSpeech2ConformerConfig
+
+    >>> # Initializing a FastSpeech2Conformer style configuration
+    >>> configuration = FastSpeech2ConformerConfig()
+
+    >>> # Initializing a model from the FastSpeech2Conformer style configuration
+    >>> model = FastSpeech2ConformerModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "fastspeech2_conformer"
+    base_config_key = "model_config"
+    attribute_map = {"num_hidden_layers": "encoder_layers", "num_attention_heads": "encoder_num_attention_heads"}
+
+    def __init__(
+        self,
+        hidden_size=384,
+        vocab_size=78,
+        num_mel_bins=80,
+        encoder_num_attention_heads=2,
+        encoder_layers=4,
+        encoder_linear_units=1536,
+        decoder_layers=4,
+        decoder_num_attention_heads=2,
+        decoder_linear_units=1536,
+        speech_decoder_postnet_layers=5,
+        speech_decoder_postnet_units=256,
+        speech_decoder_postnet_kernel=5,
+        positionwise_conv_kernel_size=3,
+        encoder_normalize_before=False,
+        decoder_normalize_before=False,
+        encoder_concat_after=False,
+        decoder_concat_after=False,
+        reduction_factor=1,
+        speaking_speed=1.0,
+        use_macaron_style_in_conformer=True,
+        use_cnn_in_conformer=True,
+        encoder_kernel_size=7,
+        decoder_kernel_size=31,
+        duration_predictor_layers=2,
+        duration_predictor_channels=256,
+        duration_predictor_kernel_size=3,
+        energy_predictor_layers=2,
+        energy_predictor_channels=256,
+        energy_predictor_kernel_size=3,
+        energy_predictor_dropout=0.5,
+        energy_embed_kernel_size=1,
+        energy_embed_dropout=0.0,
+        stop_gradient_from_energy_predictor=False,
+        pitch_predictor_layers=5,
+        pitch_predictor_channels=256,
+        pitch_predictor_kernel_size=5,
+        pitch_predictor_dropout=0.5,
+        pitch_embed_kernel_size=1,
+        pitch_embed_dropout=0.0,
+        stop_gradient_from_pitch_predictor=True,
+        encoder_dropout_rate=0.2,
+        encoder_positional_dropout_rate=0.2,
+        encoder_attention_dropout_rate=0.2,
+        decoder_dropout_rate=0.2,
+        decoder_positional_dropout_rate=0.2,
+        decoder_attention_dropout_rate=0.2,
+        duration_predictor_dropout_rate=0.2,
+        speech_decoder_postnet_dropout=0.5,
+        max_source_positions=5000,
+        use_masking=True,
+        use_weighted_masking=False,
+        num_speakers=None,
+        num_languages=None,
+        speaker_embed_dim=None,
+        is_encoder_decoder=True,
+        **kwargs,
+    ):
+        if positionwise_conv_kernel_size % 2 == 0:
+            raise ValueError(
+                f"positionwise_conv_kernel_size must be odd, but got {positionwise_conv_kernel_size} instead."
+            )
+        if encoder_kernel_size % 2 == 0:
+            raise ValueError(f"encoder_kernel_size must be odd, but got {encoder_kernel_size} instead.")
+        if decoder_kernel_size % 2 == 0:
+            raise ValueError(f"decoder_kernel_size must be odd, but got {decoder_kernel_size} instead.")
+        if duration_predictor_kernel_size % 2 == 0:
+            raise ValueError(
+                f"duration_predictor_kernel_size must be odd, but got {duration_predictor_kernel_size} instead."
+            )
+        if energy_predictor_kernel_size % 2 == 0:
+            raise ValueError(
+                f"energy_predictor_kernel_size must be odd, but got {energy_predictor_kernel_size} instead."
+            )
+        if energy_embed_kernel_size % 2 == 0:
+            raise ValueError(f"energy_embed_kernel_size must be odd, but got {energy_embed_kernel_size} instead.")
+        if pitch_predictor_kernel_size % 2 == 0:
+            raise ValueError(
+                f"pitch_predictor_kernel_size must be odd, but got {pitch_predictor_kernel_size} instead."
+            )
+        if pitch_embed_kernel_size % 2 == 0:
+            raise ValueError(f"pitch_embed_kernel_size must be odd, but got {pitch_embed_kernel_size} instead.")
+        if hidden_size % encoder_num_attention_heads != 0:
+            raise ValueError("The hidden_size must be evenly divisible by encoder_num_attention_heads.")
+        if hidden_size % decoder_num_attention_heads != 0:
+            raise ValueError("The hidden_size must be evenly divisible by decoder_num_attention_heads.")
+        if use_masking and use_weighted_masking:
+            raise ValueError("Either use_masking or use_weighted_masking can be True, but not both.")
+
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.num_mel_bins = num_mel_bins
+        self.encoder_config = {
+            "num_attention_heads": encoder_num_attention_heads,
+            "layers": encoder_layers,
+            "kernel_size": encoder_kernel_size,
+            "attention_dropout_rate": encoder_attention_dropout_rate,
+            "dropout_rate": encoder_dropout_rate,
+            "positional_dropout_rate": encoder_positional_dropout_rate,
+            "linear_units": encoder_linear_units,
+            "normalize_before": encoder_normalize_before,
+            "concat_after": encoder_concat_after,
+        }
+        self.decoder_config = {
+            "num_attention_heads": decoder_num_attention_heads,
+            "layers": decoder_layers,
+            "kernel_size": decoder_kernel_size,
+            "attention_dropout_rate": decoder_attention_dropout_rate,
+            "dropout_rate": decoder_dropout_rate,
+            "positional_dropout_rate": decoder_positional_dropout_rate,
+            "linear_units": decoder_linear_units,
+            "normalize_before": decoder_normalize_before,
+            "concat_after": decoder_concat_after,
+        }
+        self.encoder_num_attention_heads = encoder_num_attention_heads
+        self.encoder_layers = encoder_layers
+        self.duration_predictor_channels = duration_predictor_channels
+        self.duration_predictor_kernel_size = duration_predictor_kernel_size
+        self.duration_predictor_layers = duration_predictor_layers
+        self.energy_embed_dropout = energy_embed_dropout
+        self.energy_embed_kernel_size = energy_embed_kernel_size
+        self.energy_predictor_channels = energy_predictor_channels
+        self.energy_predictor_dropout = energy_predictor_dropout
+        self.energy_predictor_kernel_size = energy_predictor_kernel_size
+        self.energy_predictor_layers = energy_predictor_layers
+        self.pitch_embed_dropout = pitch_embed_dropout
+        self.pitch_embed_kernel_size = pitch_embed_kernel_size
+        self.pitch_predictor_channels = pitch_predictor_channels
+        self.pitch_predictor_dropout = pitch_predictor_dropout
+        self.pitch_predictor_kernel_size = pitch_predictor_kernel_size
+        self.pitch_predictor_layers = pitch_predictor_layers
+        self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
+        self.speech_decoder_postnet_units = speech_decoder_postnet_units
+        self.speech_decoder_postnet_dropout = speech_decoder_postnet_dropout
+        self.speech_decoder_postnet_kernel = speech_decoder_postnet_kernel
+        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
+        self.reduction_factor = reduction_factor
+        self.speaking_speed = speaking_speed
+        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
+        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
+        self.max_source_positions = max_source_positions
+        self.use_cnn_in_conformer = use_cnn_in_conformer
+        self.use_macaron_style_in_conformer = use_macaron_style_in_conformer
+        self.use_masking = use_masking
+        self.use_weighted_masking = use_weighted_masking
+        self.num_speakers = num_speakers
+        self.num_languages = num_languages
+        self.speaker_embed_dim = speaker_embed_dim
+        self.duration_predictor_dropout_rate = duration_predictor_dropout_rate
+        self.is_encoder_decoder = is_encoder_decoder
+
+        super().__init__(
+            is_encoder_decoder=is_encoder_decoder,
+            **kwargs,
+        )
+
+
+class FastSpeech2ConformerHifiGanConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`FastSpeech2ConformerHifiGanModel`]. It is used to
+    instantiate a FastSpeech2Conformer HiFi-GAN vocoder model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    FastSpeech2Conformer
+    [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        model_in_dim (`int`, *optional*, defaults to 80):
+            The number of frequency bins in the input log-mel spectrogram.
+        upsample_initial_channel (`int`, *optional*, defaults to 512):
+            The number of input channels into the upsampling network.
+        upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 2, 2]`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The
+            length of *upsample_rates* defines the number of convolutional layers and has to match the length of
+            *upsample_kernel_sizes*.
+        upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[16, 16, 4, 4]`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The
+            length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of
+            *upsample_rates*.
+        resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`):
+            A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field
+            fusion (MRF) module.
+        resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
+            A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
+            multi-receptive field fusion (MRF) module.
+        initializer_range (`float`, *optional*, defaults to 0.01):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        leaky_relu_slope (`float`, *optional*, defaults to 0.1):
+            The angle of the negative slope used by the leaky ReLU activation.
+        normalize_before (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance.
+
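+    The overall upsampling factor of the vocoder is the product of `upsample_rates`; as an
+    illustrative sketch with the defaults above, each input spectrogram frame expands to
+    8 * 8 * 2 * 2 = 256 waveform samples:
+
+    ```python
+    >>> import math
+    >>> math.prod([8, 8, 2, 2])
+    256
+    ```
+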
+    Example:
+
+    ```python
+    >>> from transformers import FastSpeech2ConformerHifiGan, FastSpeech2ConformerHifiGanConfig
+
+    >>> # Initializing a FastSpeech2ConformerHifiGan configuration
+    >>> configuration = FastSpeech2ConformerHifiGanConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = FastSpeech2ConformerHifiGan(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "hifigan"
+    base_config_key = "vocoder_config"
+
+    def __init__(
+        self,
+        model_in_dim=80,
+        upsample_initial_channel=512,
+        upsample_rates=[8, 8, 2, 2],
+        upsample_kernel_sizes=[16, 16, 4, 4],
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        initializer_range=0.01,
+        leaky_relu_slope=0.1,
+        normalize_before=True,
+        **kwargs,
+    ):
+        self.model_in_dim = model_in_dim
+        self.upsample_initial_channel = upsample_initial_channel
+        self.upsample_rates = upsample_rates
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.initializer_range = initializer_range
+        self.leaky_relu_slope = leaky_relu_slope
+        self.normalize_before = normalize_before
+        super().__init__(**kwargs)
+
+
+class FastSpeech2ConformerWithHifiGanConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`FastSpeech2ConformerWithHifiGan`]. It is used to
+    instantiate a `FastSpeech2ConformerWithHifiGanModel` model according to the specified sub-models configurations,
+    defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    FastSpeech2ConformerModel [espnet/fastspeech2_conformer](https://huggingface.co/espnet/fastspeech2_conformer) and
+    FastSpeech2ConformerHifiGan
+    [espnet/fastspeech2_conformer_hifigan](https://huggingface.co/espnet/fastspeech2_conformer_hifigan) architectures.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        model_config (`dict`, *optional*):
+            Configuration dictionary used to instantiate the text-to-speech model ([`FastSpeech2ConformerConfig`]).
+        vocoder_config (`dict`, *optional*):
+            Configuration dictionary used to instantiate the vocoder model ([`FastSpeech2ConformerHifiGanConfig`]).
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     FastSpeech2ConformerConfig,
+    ...     FastSpeech2ConformerHifiGanConfig,
+    ...     FastSpeech2ConformerWithHifiGanConfig,
+    ...     FastSpeech2ConformerWithHifiGan,
+    ... )
+
+    >>> # Initializing FastSpeech2ConformerWithHifiGan sub-modules configurations.
+    >>> model_config = FastSpeech2ConformerConfig()
+    >>> vocoder_config = FastSpeech2ConformerHifiGanConfig()
+
+    >>> # Initializing a FastSpeech2ConformerWithHifiGan module style configuration
+    >>> configuration = FastSpeech2ConformerWithHifiGanConfig(model_config.to_dict(), vocoder_config.to_dict())
+
+    >>> # Initializing a model (with random weights)
+    >>> model = FastSpeech2ConformerWithHifiGan(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "fastspeech2_conformer_with_hifigan"
+    sub_configs = {"model_config": FastSpeech2ConformerConfig, "vocoder_config": FastSpeech2ConformerHifiGanConfig}
+
+    def __init__(
+        self,
+        model_config: Optional[dict] = None,
+        vocoder_config: Optional[dict] = None,
+        **kwargs,
+    ):
+        if model_config is None:
+            model_config = {}
+            logger.info("model_config is None. initializing the model with default values.")
+
+        if vocoder_config is None:
+            vocoder_config = {}
+            logger.info("vocoder_config is None. initializing the coarse model with default values.")
+
+        self.model_config = FastSpeech2ConformerConfig(**model_config)
+        self.vocoder_config = FastSpeech2ConformerHifiGanConfig(**vocoder_config)
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["FastSpeech2ConformerConfig", "FastSpeech2ConformerHifiGanConfig", "FastSpeech2ConformerWithHifiGanConfig"]
diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2dc39385b3c953eb256bb03c04d6456c8f8890
--- /dev/null
+++ b/URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py
@@ -0,0 +1,1588 @@
+# coding=utf-8
+# Copyright 2023 The Espnet authors, IMS Toucan authors, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch FastSpeech2Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, logging
+from .configuration_fastspeech2_conformer import (
+    FastSpeech2ConformerConfig,
+    FastSpeech2ConformerHifiGanConfig,
+    FastSpeech2ConformerWithHifiGanConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Output type of [`FastSpeech2ConformerModel`].
+    """
+)
+class FastSpeech2ConformerModelOutput(ModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Spectrogram generation loss.
+    duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
+        Outputs of the duration predictor.
+    pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+        Outputs of the pitch predictor.
+    energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+        Outputs of the energy predictor.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    spectrogram: Optional[torch.FloatTensor] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
+    duration_outputs: Optional[torch.LongTensor] = None
+    pitch_outputs: Optional[torch.FloatTensor] = None
+    energy_outputs: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+    custom_intro="""
+    Output type of [`FastSpeech2ConformerWithHifiGan`].
+    """
+)
+class FastSpeech2ConformerWithHifiGanOutput(FastSpeech2ConformerModelOutput):
+    r"""
+    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+        Spectrogram generation loss.
+    duration_outputs (`torch.LongTensor` of shape `(batch_size, max_text_length + 1)`, *optional*):
+        Outputs of the duration predictor.
+    pitch_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+        Outputs of the pitch predictor.
+    energy_outputs (`torch.FloatTensor` of shape `(batch_size, max_text_length + 1, 1)`, *optional*):
+        Outputs of the energy predictor.
+    waveform (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
+        Speech output as a result of passing the predicted mel spectrogram through the vocoder.
+    """
+
+    waveform: Optional[torch.FloatTensor] = None
+
+
+def length_regulator(encoded_embeddings, duration_labels, speaking_speed=1.0):
+    """
+    Length regulator for feed-forward Transformer.
+
+    This is the length regulator module described in `FastSpeech: Fast, Robust and Controllable Text to Speech`
+    https://huggingface.co/papers/1905.09263. The length regulator expands char or phoneme-level embedding features to
+    frame-level by repeating each feature based on the corresponding predicted durations.
+
+    Args:
+        encoded_embeddings (`torch.Tensor` of shape `(batch_size, max_text_length, embedding_dim)`):
+            Batch of sequences of char or phoneme embeddings.
+        duration_labels (`torch.LongTensor` of shape `(batch_size, time)`):
+            Batch of durations of each frame.
+        speaking_speed (`float`, *optional*, defaults to 1.0):
+            Value to control speed of speech.
+
+    Returns:
+        `torch.Tensor`:
+            Replicated input tensor based on durations (batch_size, time*, embedding_dim).
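+
+    Example (illustrative sketch of the duration-based expansion):
+
+    ```python
+    >>> import torch
+    >>> embeddings = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # 3 phonemes, embedding_dim=2
+    >>> durations = torch.tensor([[2, 0, 3]])
+    >>> length_regulator(embeddings, durations).shape
+    torch.Size([1, 5, 2])
+    ```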
+    """
+
+    if speaking_speed <= 0:
+        raise ValueError("`speaking_speed` must be greater than 0.")
+    elif speaking_speed != 1.0:
+        duration_labels = torch.round(duration_labels.float() * speaking_speed).long()
+
+    if duration_labels.sum() == 0:
+        duration_labels[duration_labels.sum(dim=1).eq(0)] = 1
+
+    # Calculate the maximum length needed
+    max_len = torch.sum(duration_labels, dim=1).max()
+
+    # Create a padded tensor to hold the results
+    hidden_states = torch.zeros(
+        (encoded_embeddings.size(0), max_len, encoded_embeddings.size(2)),
+        dtype=torch.float,
+        device=encoded_embeddings.device,
+    )
+
+    # Loop through the batch and fill in the data
+    for i, (encoded_embedding, target_duration) in enumerate(zip(encoded_embeddings, duration_labels)):
+        repeated = torch.repeat_interleave(encoded_embedding, target_duration, dim=0)
+        hidden_states[i, : repeated.size(0)] = repeated
+
+    return hidden_states
+
+
+class FastSpeech2ConformerDurationPredictor(nn.Module):
+    """
+    Duration predictor module.
+
+    This is the duration predictor module described in the paper 'FastSpeech: Fast, Robust and Controllable Text to
+    Speech' (https://huggingface.co/papers/1905.09263). It predicts the duration of each frame in the log domain
+    from the hidden embeddings of the encoder.
+
+    Note:
+        The output domain differs between training and evaluation. In training mode the predictions are returned in
+        the log domain, while in evaluation mode they are converted to rounded durations in the linear domain.
+
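+    Example (illustrative sketch of the log-to-linear conversion applied in evaluation mode):
+
+    ```python
+    >>> import torch
+    >>> log_durations = torch.tensor([[0.0, 1.5, -3.0]])
+    >>> torch.clamp(torch.round(log_durations.exp() - 1.0), min=0).long()
+    tensor([[0, 3, 0]])
+    ```
+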
+    """
+
+    def __init__(self, config: FastSpeech2ConformerConfig):
+        super().__init__()
+
+        self.conv_layers = nn.ModuleList()
+        self.log_domain_offset = 1.0
+
+        for layer_idx in range(config.duration_predictor_layers):
+            num_chans = config.duration_predictor_channels
+            input_channels = config.hidden_size if layer_idx == 0 else num_chans
+            layer = FastSpeech2ConformerPredictorLayer(
+                input_channels,
+                num_chans,
+                config.duration_predictor_kernel_size,
+                config.duration_predictor_dropout_rate,
+            )
+            self.conv_layers.append(layer)
+        self.linear = nn.Linear(config.duration_predictor_channels, 1)
+
+    def forward(self, encoder_hidden_states):
+        """
+        Args:
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
+                Batch of input sequences.
+
+        Returns:
+            `torch.Tensor`: Batch of predicted durations in log domain `(batch_size, max_text_length)`.
+
+        """
+        # (batch_size, input_dim, max_text_length)
+        hidden_states = encoder_hidden_states.transpose(1, -1)
+        for layer in self.conv_layers:
+            hidden_states = layer(hidden_states)
+
+        # NOTE: calculate in log domain, (batch_size, max_text_length)
+        hidden_states = self.linear(hidden_states.transpose(1, -1)).squeeze(-1)
+
+        if not self.training:
+            # NOTE: calculate in linear domain
+            hidden_states = torch.clamp(torch.round(hidden_states.exp() - self.log_domain_offset), min=0).long()
+
+        return hidden_states
+
+
+# Copied from transformers.models.speecht5.modeling_speecht5.SpeechT5BatchNormConvLayer
+class FastSpeech2ConformerBatchNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+
+        if layer_id == 0:
+            in_conv_dim = config.num_mel_bins
+        else:
+            in_conv_dim = config.speech_decoder_postnet_units
+
+        if layer_id == config.speech_decoder_postnet_layers - 1:
+            out_conv_dim = config.num_mel_bins
+        else:
+            out_conv_dim = config.speech_decoder_postnet_units
+
+        self.conv = nn.Conv1d(
+            in_conv_dim,
+            out_conv_dim,
+            kernel_size=config.speech_decoder_postnet_kernel,
+            stride=1,
+            padding=(config.speech_decoder_postnet_kernel - 1) // 2,
+            bias=False,
+        )
+        self.batch_norm = nn.BatchNorm1d(out_conv_dim)
+
+        if layer_id < config.speech_decoder_postnet_layers - 1:
+            self.activation = nn.Tanh()
+        else:
+            self.activation = None
+
+        self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.batch_norm(hidden_states)
+        if self.activation is not None:
+            hidden_states = self.activation(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class FastSpeech2ConformerSpeechDecoderPostnet(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
+        self.layers = nn.ModuleList(
+            [FastSpeech2ConformerBatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
+        )
+
+    def forward(self, hidden_states: torch.Tensor):
+        outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
+        layer_output = outputs_before_postnet.transpose(1, 2)
+        for layer in self.layers:
+            layer_output = layer(layer_output)
+        outputs_after_postnet = outputs_before_postnet + layer_output.transpose(1, 2)
+        return outputs_before_postnet, outputs_after_postnet
+
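+# Illustrative sketch of the reshape performed by feat_out above (hypothetical
+# helper with made-up sizes, not used by the model): each encoder frame predicts
+# `num_mel_bins * reduction_factor` values, which view() unfolds into
+# `reduction_factor` mel frames, so the postnet sees (batch, time * reduction_factor, num_mel_bins).
+def _illustrate_postnet_reshape(batch_size=2, time=5, num_mel_bins=80, reduction_factor=2):
+    flat = torch.randn(batch_size, time, num_mel_bins * reduction_factor)  # feat_out-style output
+    spectrogram = flat.view(batch_size, -1, num_mel_bins)
+    return spectrogram.shape  # torch.Size([2, 10, 80])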
+
+class FastSpeech2ConformerPredictorLayer(nn.Module):
+    def __init__(self, input_channels, num_chans, kernel_size, dropout_rate):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            input_channels,
+            num_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.activation = nn.ReLU()
+        self.layer_norm = nn.LayerNorm(num_chans)
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.activation(hidden_states)
+
+        # Perform layer norm on dimension 1
+        hidden_states = hidden_states.transpose(1, -1)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(1, -1)
+
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+class FastSpeech2ConformerVariancePredictor(nn.Module):
+    def __init__(
+        self,
+        config: FastSpeech2ConformerConfig,
+        num_layers=2,
+        num_chans=384,
+        kernel_size=3,
+        dropout_rate=0.5,
+    ):
+        """
+        Initialize variance predictor module.
+
+        Args:
+            config (`FastSpeech2ConformerConfig`): Model configuration, used for the input hidden size.
+            num_layers (`int`, *optional*, defaults to 2): Number of convolutional layers.
+            num_chans (`int`, *optional*, defaults to 384): Number of channels of convolutional layers.
+            kernel_size (`int`, *optional*, defaults to 3): Kernel size of convolutional layers.
+            dropout_rate (`float`, *optional*, defaults to 0.5): Dropout rate.
+        """
+        super().__init__()
+        self.conv_layers = nn.ModuleList()
+        for idx in range(num_layers):
+            input_channels = config.hidden_size if idx == 0 else num_chans
+            layer = FastSpeech2ConformerPredictorLayer(input_channels, num_chans, kernel_size, dropout_rate)
+            self.conv_layers.append(layer)
+        self.linear = nn.Linear(num_chans, 1)
+
+    def forward(self, encoder_hidden_states, padding_masks=None):
+        """
+        Calculate forward propagation.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, max_text_length, input_dim)`):
+                Batch of input sequences.
+            padding_masks (`torch.ByteTensor` of shape `(batch_size, max_text_length)`, *optional*):
+                Batch of masks indicating padded part.
+
+        Returns:
+            `torch.Tensor`: Batch of predicted sequences `(batch_size, max_text_length, 1)`.
+        """
+        # (batch_size, input_dim, max_text_length)
+        hidden_states = encoder_hidden_states.transpose(1, -1)
+        for layer in self.conv_layers:
+            hidden_states = layer(hidden_states)
+
+        hidden_states = self.linear(hidden_states.transpose(1, 2))
+
+        if padding_masks is not None:
+            hidden_states = hidden_states.masked_fill(padding_masks, 0.0)
+
+        return hidden_states
+
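+# Illustrative sketch of the padding mask handling above (hypothetical helper,
+# not used by the model): padded positions are zeroed after the final linear
+# projection so they contribute nothing to the pitch/energy targets.
+def _illustrate_variance_padding_mask():
+    predictions = torch.randn(1, 4, 1)  # (batch_size, max_text_length, 1)
+    padding_masks = torch.tensor([[[False], [False], [True], [True]]])  # last two positions are padding
+    return predictions.masked_fill(padding_masks, 0.0)  # padded entries become exactly 0.0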
+
+class FastSpeech2ConformerVarianceEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels=1,
+        out_channels=384,
+        kernel_size=1,
+        padding=0,
+        dropout_rate=0.0,
+    ):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+        )
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
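+# Illustrative sketch of the variance embedding above (hypothetical helper, not
+# used by the model): a scalar pitch/energy value per position is projected back
+# to the model width with a (by default) kernel-size-1 convolution.
+def _illustrate_variance_embedding():
+    embed = FastSpeech2ConformerVarianceEmbedding(in_channels=1, out_channels=384)
+    pitch = torch.randn(2, 7, 1)  # one value per text position
+    return embed(pitch).shape  # torch.Size([2, 7, 384])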
+
+class FastSpeech2ConformerAttention(nn.Module):
+    """
+    Multi-Head attention layer with relative position encoding. Details can be found in
+    https://github.com/espnet/espnet/pull/2816. Paper: https://huggingface.co/papers/1901.02860.
+    """
+
+    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+        """Construct an FastSpeech2ConformerAttention object."""
+        super().__init__()
+        # We assume d_v always equals dim_key
+        self.num_heads = module_config["num_attention_heads"]
+        self.hidden_size = config.hidden_size
+        self.dim_key = self.hidden_size // self.num_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.linear_q = nn.Linear(self.hidden_size, self.hidden_size)
+        self.linear_k = nn.Linear(self.hidden_size, self.hidden_size)
+        self.linear_v = nn.Linear(self.hidden_size, self.hidden_size)
+        self.linear_out = nn.Linear(self.hidden_size, self.hidden_size)
+        self.dropout = nn.Dropout(p=module_config["attention_dropout_rate"])
+
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        # these two learnable biases are used in matrix c and matrix d
+        # as described in https://huggingface.co/papers/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_dim))
+
+    def shift_relative_position_tensor(self, pos_tensor):
+        """
+        Args:
+            pos_tensor (torch.Tensor of shape (batch_size, head, time1, 2*time1-1)): Input tensor.
+        """
+        zero_pad = torch.zeros((*pos_tensor.size()[:3], 1), device=pos_tensor.device, dtype=pos_tensor.dtype)
+        pos_tensor_padded = torch.cat([zero_pad, pos_tensor], dim=-1)
+
+        pos_tensor_padded = pos_tensor_padded.view(*pos_tensor.size()[:2], pos_tensor.size(3) + 1, pos_tensor.size(2))
+        # only keep the positions from 0 to time2
+        pos_tensor = pos_tensor_padded[:, :, 1:].view_as(pos_tensor)[:, :, :, : pos_tensor.size(-1) // 2 + 1]
+
+        return pos_tensor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        pos_emb: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch, time2, size)`): Values of the hidden states
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time2)`): Mask tensor.
+            pos_emb (`torch.Tensor` of shape `(batch, 2*time1-1, size)`): Positional embedding tensor.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        Returns:
+            `torch.Tensor`: Output tensor of shape `(batch, time1, d_model)`.
+        """
+        bsz, q_len, _ = hidden_states.size()
+        query_states = self.linear_q(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+        key_states = self.linear_k(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+        value_states = self.linear_v(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+
+        bsz_pos = pos_emb.size(0)
+        pos_encoding = self.linear_pos(pos_emb).view(bsz_pos, -1, self.num_heads, self.head_dim)
+
+        # (batch_size, head, time1, dim_key)
+        query_with_bias_u = (query_states + self.pos_bias_u).transpose(1, 2)
+        # (batch_size, head, time1, dim_key)
+        query_with_bias_v = (query_states + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://huggingface.co/papers/1901.02860 Section 3.3
+        # (batch_size, head, time1, time2)
+        matrix_ac = torch.matmul(query_with_bias_u, key_states.permute(0, 2, 3, 1))
+
+        # compute matrix b and matrix d
+        # (batch_size, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(query_with_bias_v, pos_encoding.permute(0, 2, 3, 1))
+        matrix_bd = self.shift_relative_position_tensor(matrix_bd)
+
+        # (batch_size, head, time1, time2)
+        scores = (matrix_ac + matrix_bd) / math.sqrt(self.dim_key)
+
+        # Forward attention
+        if attention_mask is not None:
+            expected_size = (bsz, 1, q_len)
+            if attention_mask.size() != expected_size:
+                raise ValueError(f"Attention mask should be of size {expected_size}, but is {attention_mask.size()}")
+            attention_mask = attention_mask.unsqueeze(1).eq(0)
+            min_value = float(torch.finfo(scores.dtype).min)
+            scores = scores.masked_fill(attention_mask, min_value)
+            attn_weights = torch.softmax(scores, dim=-1).masked_fill(attention_mask, 0.0)
+        else:
+            attn_weights = torch.softmax(scores, dim=-1)
+
+        attn_weights = self.dropout(attn_weights)
+        attn_output = torch.matmul(attn_weights, value_states.transpose(1, 2))
+        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
+
+        attn_output = self.linear_out(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights
+
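+# Worked example of shift_relative_position_tensor above (hypothetical helper,
+# not used by the model): if column k of the input scores relative distance
+# (time1 - 1 - k), the pad-and-reshape trick re-indexes it so that output[i, j]
+# scores relative distance (i - j), i.e. query position minus key position.
+def _illustrate_relative_shift():
+    time1 = 3
+    # x[..., i, k] = relative distance encoded by column k: 2, 1, 0, -1, -2
+    x = torch.arange(time1 - 1, -time1, -1, dtype=torch.float).expand(1, 1, time1, 2 * time1 - 1)
+    # the method does not use `self`, so it can be called unbound for illustration
+    shifted = FastSpeech2ConformerAttention.shift_relative_position_tensor(None, x)
+    # shifted[0, 0, i, j] == i - j: row 0 is [0, -1, -2], row 2 is [2, 1, 0]
+    return shifted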
+
+class FastSpeech2ConformerConvolutionModule(nn.Module):
+    def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
+        """
+        Args:
+            config (FastSpeech2ConformerConfig): Configuration for the model.
+            module_config (dict): Configuration for the module (e.g., encoder or decoder).
+        """
+        super().__init__()
+        channels = config.hidden_size
+        # kernel_size should be an odd number for 'SAME' padding
+        if module_config is None:
+            # e.g. using `ParakeetEncoderConfig` in src/transformers/models/parakeet/configuration_parakeet.py
+            kernel_size = config.conv_kernel_size
+            self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
+        else:
+            kernel_size = module_config["kernel_size"]
+            self.activation = ACT2FN[module_config.get("activation", "silu")]
+        self.padding = (kernel_size - 1) // 2
+        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
+        self.depthwise_conv = nn.Conv1d(
+            channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
+
+    def forward(self, hidden_states, attention_mask=None):
+        """
+        Compute convolution module.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+
+        Returns:
+            `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        hidden_states = hidden_states.transpose(1, 2)
+
+        # GLU mechanism, (batch_size, 2*channel, dim)
+        hidden_states = self.pointwise_conv1(hidden_states)
+        # (batch_size, channel, dim)
+        hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+        # Apply padding mask before convolution
+        if attention_mask is not None:
+            all_masked_rows = torch.all(~attention_mask, dim=-1)
+            hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
+
+        # 1D Depthwise Conv
+        hidden_states = self.depthwise_conv(hidden_states)
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        hidden_states = self.pointwise_conv2(hidden_states)
+
+        return hidden_states.transpose(1, 2)
+
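+# Illustrative sketch of the GLU step above (hypothetical helper, not used by
+# the model): pointwise_conv1 doubles the channel count so the GLU can split it
+# into a content half and a sigmoid gate half, restoring the original width.
+def _illustrate_glu_gating(channels=384):
+    x = torch.randn(2, 2 * channels, 50)  # (batch, 2 * channels, time) after pointwise_conv1
+    gated = nn.functional.glu(x, dim=1)   # content * sigmoid(gate)
+    return gated.shape                    # torch.Size([2, 384, 50])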
+
+class FastSpeech2ConformerEncoderLayer(nn.Module):
+    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+        super().__init__()
+
+        # self-attention module definition
+        self.self_attn = FastSpeech2ConformerAttention(config, module_config)
+
+        # feed-forward module definition
+        self.feed_forward = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
+
+        self.macaron_style = config.use_macaron_style_in_conformer
+        if self.macaron_style:
+            self.feed_forward_macaron = FastSpeech2ConformerMultiLayeredConv1d(config, module_config)
+            self.ff_macaron_layer_norm = nn.LayerNorm(config.hidden_size)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+
+        # convolution module definition
+        self.use_cnn_module = config.use_cnn_in_conformer
+        if self.use_cnn_module:
+            self.conv_module = FastSpeech2ConformerConvolutionModule(config, module_config)
+            self.conv_layer_norm = nn.LayerNorm(config.hidden_size)
+            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
+
+        self.ff_layer_norm = nn.LayerNorm(config.hidden_size)
+
+        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
+
+        self.dropout = nn.Dropout(module_config["dropout_rate"])
+        self.size = config.hidden_size
+        self.normalize_before = module_config["normalize_before"]
+        self.concat_after = module_config["concat_after"]
+        if self.concat_after:
+            self.concat_linear = nn.Linear(config.hidden_size + config.hidden_size, config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pos_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ):
+        """
+        Compute encoded features.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch, time, size)`): Input tensor.
+            pos_emb (`torch.Tensor` of shape `(1, 2*time-1, size)`): Relative positional embeddings tensor.
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask tensor for the input.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        Returns:
+            `torch.Tensor`: Output tensor of shape `(batch, time, size)`.
+
+        """
+        # whether to use macaron style
+        if self.macaron_style:
+            residual = hidden_states
+            if self.normalize_before:
+                hidden_states = self.ff_macaron_layer_norm(hidden_states)
+            hidden_states = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(hidden_states))
+            if not self.normalize_before:
+                hidden_states = self.ff_macaron_layer_norm(hidden_states)
+
+        # multi-headed self-attention module
+        residual = hidden_states
+        if self.normalize_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        attention_output, attention_scores = self.self_attn(
+            hidden_states, attention_mask=attention_mask, pos_emb=pos_emb, output_attentions=output_attentions
+        )
+
+        if self.concat_after:
+            x_concat = torch.cat((hidden_states, attention_output), dim=-1)
+            hidden_states = self.concat_linear(x_concat)
+            hidden_states = residual + hidden_states
+        else:
+            hidden_states = self.dropout(attention_output)
+            hidden_states = residual + hidden_states
+        if not self.normalize_before:
+            hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # convolution module
+        if self.use_cnn_module:
+            residual = hidden_states
+            if self.normalize_before:
+                hidden_states = self.conv_layer_norm(hidden_states)
+            hidden_states = self.conv_module(hidden_states)
+            hidden_states = self.dropout(hidden_states)
+            hidden_states = residual + hidden_states
+            if not self.normalize_before:
+                hidden_states = self.conv_layer_norm(hidden_states)
+
+        # feed forward module
+        residual = hidden_states
+        if self.normalize_before:
+            hidden_states = self.ff_layer_norm(hidden_states)
+        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = residual + self.ff_scale * hidden_states
+        if not self.normalize_before:
+            hidden_states = self.ff_layer_norm(hidden_states)
+
+        if self.use_cnn_module:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attention_scores,)
+
+        return outputs
+
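+# Summary sketch (illustrative only, dropout omitted for brevity) of the
+# sub-block ordering implemented by the forward pass above, assuming pre-norm
+# (normalize_before=True), macaron style and the convolution module are all
+# enabled, and concat_after is False:
+#
+#   x = x + 0.5 * FFN_macaron(LN(x))
+#   x = x + SelfAttention(LN(x), pos_emb)
+#   x = x + Conv(LN(x))
+#   x = x + 0.5 * FFN(LN(x))
+#   x = LN_final(x)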
+
+class FastSpeech2ConformerMultiLayeredConv1d(nn.Module):
+    """
+    Multi-layered conv1d for Transformer block.
+
+    This is a module of multi-layered conv1d designed to replace positionwise feed-forward network in Transformer
+    block, which is introduced in 'FastSpeech: Fast, Robust and Controllable Text to Speech'
+    https://huggingface.co/papers/1905.09263
+    """
+
+    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+        """
+        Initialize FastSpeech2ConformerMultiLayeredConv1d module.
+
+        Args:
+            config (`FastSpeech2ConformerConfig`): Model configuration, providing the hidden size and the
+                positionwise convolution kernel size.
+            module_config (`dict`): Encoder or decoder module configuration, providing `linear_units` and the
+                dropout rate.
+        """
+        super().__init__()
+        input_channels = config.hidden_size
+        hidden_channels = module_config["linear_units"]
+        kernel_size = config.positionwise_conv_kernel_size
+        self.conv1 = nn.Conv1d(input_channels, hidden_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+        self.conv2 = nn.Conv1d(hidden_channels, input_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+        self.dropout = nn.Dropout(module_config["dropout_rate"])
+
+    def forward(self, hidden_states):
+        """
+        Calculate forward propagation.
+
+        Args:
+            hidden_states (torch.Tensor): Batch of input tensors (batch_size, time, input_channels).
+
+        Returns:
+            torch.Tensor: Batch of output tensors (batch_size, time, hidden_channels).
+        """
+        hidden_states = hidden_states.transpose(-1, 1)
+        hidden_states = self.conv1(hidden_states)
+        hidden_states = torch.relu(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+        hidden_states = hidden_states.transpose(-1, 1)
+        return hidden_states
+
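+# Illustrative sketch of the position-wise convolution FFN above (hypothetical
+# helper, not used by the model): with kernel_size 1 it reduces to the classic
+# two-layer position-wise feed-forward network, while larger kernels let each
+# position mix information from its neighbours.
+def _illustrate_positionwise_conv_ffn(kernel_size=3, hidden_size=384, linear_units=1536):
+    conv1 = nn.Conv1d(hidden_size, linear_units, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+    conv2 = nn.Conv1d(linear_units, hidden_size, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
+    hidden_states = torch.randn(2, 30, hidden_size)  # (batch, time, hidden_size)
+    output = conv2(torch.relu(conv1(hidden_states.transpose(-1, 1)))).transpose(-1, 1)
+    return output.shape  # torch.Size([2, 30, 384])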
+
+class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
+    """
+    Args:
+    Relative positional encoding module (new implementation). Details can be found in
+    https://github.com/espnet/espnet/pull/2816. See : Appendix Batch in https://huggingface.co/papers/1901.02860
+        config (`FastSpeech2ConformerConfig`):
+            FastSpeech2ConformerConfig instance.
+        module_config (`dict`):
+            Dictionary containing the encoder or decoder module configuration from the `FastSpeech2ConformerConfig`.
+    """
+
+    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+        """
+        Construct a PositionalEncoding object.
+        """
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.input_scale = math.sqrt(self.embed_dim)
+        self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
+        self.pos_enc = None
+        self.max_len = 5000
+        self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+
+    def extend_pos_enc(self, x):
+        """Reset the positional encodings."""
+        if self.pos_enc is not None:
+            # self.pos_enc contains both positive and negative parts
+            # the length of self.pos_enc is 2 * input_len - 1
+            if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
+                if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
+                    self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of the query vector and `j` the position
+        # of the key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i