BryanW committed
Commit 2710119 · verified · 1 Parent(s): 63f0c29

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc +0 -0
  2. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc +0 -0
  3. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc +0 -0
  4. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc +0 -0
  5. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc +0 -0
  6. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc +0 -0
  7. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc +0 -0
  8. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc +0 -0
  9. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc +0 -0
  10. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc +0 -0
  11. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc +0 -0
  12. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py +26 -0
  13. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc +0 -0
  14. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc +0 -0
  15. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc +0 -0
  16. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc +0 -0
  17. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc +0 -0
  18. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc +0 -0
  19. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py +606 -0
  20. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py +226 -0
  21. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py +1047 -0
  22. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py +204 -0
  23. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py +300 -0
  24. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py +31 -0
  25. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py +170 -0
  26. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py +1349 -0
  27. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py +1132 -0
  28. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py +1572 -0
  29. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py +320 -0
  30. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py +178 -0
  31. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py +35 -0
  32. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py +882 -0
  33. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py +1404 -0
  34. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py +422 -0
  35. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py +688 -0
  36. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py +0 -0
  37. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py +413 -0
  38. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py +776 -0
  39. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py +443 -0
  40. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py +1235 -0
  41. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py +393 -0
  42. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py +28 -0
  43. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py +303 -0
  44. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py +330 -0
  45. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py +1628 -0
  46. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py +340 -0
  47. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py +32 -0
  48. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py +154 -0
  49. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py +1801 -0
  50. URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py +1727 -0
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (7.34 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc ADDED
Binary file (23.1 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc ADDED
Binary file (44.8 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc ADDED
Binary file (61 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc ADDED
Binary file (76.5 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc ADDED
Binary file (32.4 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc ADDED
Binary file (46.5 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc ADDED
Binary file (31.2 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc ADDED
Binary file (14.7 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc ADDED
Binary file (40.4 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc ADDED
Binary file (28 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py ADDED
@@ -0,0 +1,26 @@
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Inc. team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from .cache import PagedAttentionCache
+ from .continuous_api import ContinuousBatchingManager, ContinuousMixin
+ from .requests import RequestState, RequestStatus
+
+
+ __all__ = [
+     "ContinuousBatchingManager",
+     "ContinuousMixin",
+     "PagedAttentionCache",
+     "RequestState",
+     "RequestStatus",
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (501 Bytes).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc ADDED
Binary file (31.7 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc ADDED
Binary file (12.4 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc ADDED
Binary file (55.7 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc ADDED
Binary file (10.5 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc ADDED
Binary file (16 kB).
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py ADDED
@@ -0,0 +1,606 @@
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from collections import deque
+ from math import floor, gcd, sqrt
+ from typing import Optional, Union
+
+ import torch
+
+ from ...configuration_utils import PretrainedConfig
+ from ...generation.configuration_utils import GenerationConfig
+ from ...utils.metrics import attach_tracer, traced
+ from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
+ from .requests import get_device_and_memory_breakdown, logger
+
+
+ def group_layers_by_attn_type(config: PretrainedConfig) -> tuple[list[list[int]], list[str]]:
+     """
+     Group layers depending on the attention mix, according to vLLM's hybrid allocator rules:
+     - Layers in each group need to have the same type of attention
+     - All groups have the same number of layers
+
+     For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
+     we would get four groups of two layers each: [0, 3] (sliding), and [1, 2], [4, 5], [6, 7] (full).
+     """
+     # If the config has no layer_types attribute, it means all layers share the same attention type
+     layer_types = getattr(config, "layer_types", None)
+     if layer_types is None:
+         attn_type = "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention"
+         layer_types = [attn_type for _ in range(config.num_hidden_layers)]
+
+     # We then count the number of layers of each type
+     layer_counts = {}
+     for i, layer_type in enumerate(layer_types):
+         layer_counts[layer_type] = layer_counts.get(layer_type, []) + [i]
+
+     # The size of all groups is the greatest common divisor of the number of layers of each type
+     group_size = gcd(*[len(indices) for indices in layer_counts.values()])
+
+     # We then group the layers by type
+     layer_groups = []
+     for layer_type, indices in layer_counts.items():
+         for i in range(0, len(indices), group_size):
+             layer_groups.append(indices[i : i + group_size])
+     # And note the layer types
+     group_types = [layer_types[lg[0]] for lg in layer_groups]
+     return layer_groups, group_types
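
To make the grouping rule above concrete, here is a minimal standalone sketch of the same gcd-based grouping; the toy `config` below is invented for illustration and stands in for a real `PretrainedConfig`:

```python
from math import gcd
from types import SimpleNamespace

# Hypothetical stand-in for a PretrainedConfig with a hybrid attention mix
config = SimpleNamespace(
    layer_types=["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
)

# Count layer indices per attention type: {"sliding": [0, 3], "full": [1, 2, 4, 5, 6, 7]}
layer_counts = {}
for i, layer_type in enumerate(config.layer_types):
    layer_counts.setdefault(layer_type, []).append(i)

# Group size is the gcd of the per-type counts: gcd(2, 6) = 2
group_size = gcd(*[len(indices) for indices in layer_counts.values()])

# Slice each type's index list into chunks of group_size
groups = [
    indices[i : i + group_size]
    for indices in layer_counts.values()
    for i in range(0, len(indices), group_size)
]
print(groups)  # [[0, 3], [1, 2], [4, 5], [6, 7]]
```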
+
+
+ @attach_tracer()
+ class PagedAttentionCache:
+     """
+     Manages the cache for a paged attention mechanism, inspired by vLLM's hybrid allocator. The cache relies on making
+     groups of layers to reduce the complexity of cache management and fragmentation.
+
+     The cache uses a three-level hierarchy:
+     - Pages: The smallest unit of cache. A page has a size of [num_heads, head_size], which is the space needed to
+       store the key or value states for one token and one layer. For a model with only full-attention layers, to store
+       the KV cache of one token, we need `2 * num_layers` pages: keys and values each take `num_layers` pages.
+       Pages are grouped into blocks:
+     - Blocks: A block is a collection of `block_size` pages, serving as the allocation unit to reduce management
+       complexity and fragmentation. Cache is allocated and freed block by block, not page by page. One block is
+       allocated to one layer group, which only has one attention type, like full-attention or sliding-attention.
+       If all layers in the model have the same attention type, then all layers will be in the same group. There is
+       more than one group if and only if the model has mixed attention types, like layers with full-attention and
+       layers with sliding-attention.
+     - Cache tensors: The physical backing for the cache. There are as many cache tensors as there are layers in a
+       layer group, and the shape of each cache tensor is `[num_blocks * block_size, num_heads, head_size]`.
+
+     Grouping layers this way is useful because when we allocate one block to a group N, the block allocated is the
+     same for all layers in group N; equivalently, it is allocated across all cache tensors. This allows us to
+     efficiently allocate and free blocks, and to efficiently read and write key and value states.
+
+     For instance, imagine we have 8 blocks of cache and a model with two layer groups: a full-attention group with 3
+     layers and a sliding-attention group with 3 layers. At creation time, the physical cache tensors look like this:
+
+     cache_tensor_0: □ □ □ □ □ □ □ □
+     cache_tensor_1: □ □ □ □ □ □ □ □
+     cache_tensor_2: □ □ □ □ □ □ □ □
+
+     where □ means the block is not allocated to any layer group yet. We have 3 cache tensors because there are
+     3 layers per group.
+     We allocate 1 block to each group; after allocation, the cache tensors look like this:
+
+     cache_tensor_0: ✖ ◉ □ □ □ □ □ □
+     cache_tensor_1: ✖ ◉ □ □ □ □ □ □
+     cache_tensor_2: ✖ ◉ □ □ □ □ □ □
+
+     where ✖ means the block is allocated to the full-attention group, and ◉ means the block is allocated to the
+     sliding-attention group.
+     Now, if we continue to generate and the sliding window has been reached, we only need to allocate a new block
+     for the full-attention group, and the cache tensors look like this:
+
+     cache_tensor_0: ✖ ◉ ✖ □ □ □ □ □
+     cache_tensor_1: ✖ ◉ ✖ □ □ □ □ □
+     cache_tensor_2: ✖ ◉ ✖ □ □ □ □ □
+
+     And after further generation, when we need a new block allocated:
+
+     cache_tensor_0: ✖ ◉ ✖ ✖ □ □ □ □
+     cache_tensor_1: ✖ ◉ ✖ ✖ □ □ □ □
+     cache_tensor_2: ✖ ◉ ✖ ✖ □ □ □ □
+
+     This would not have been possible if all layers were in the same group: we would have had to allocate a new block
+     for the sliding-attention group as well, although it is not needed.
+     """
+
+     # TODO: this init is quite long, maybe a refactor is in order
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         generation_config: GenerationConfig,
+         device: torch.device,
+         dtype: torch.dtype = torch.float16,
+         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
+         tp_size: Optional[int] = None,
+     ) -> None:
+         """Initialize a paged attention cache for efficient memory usage.
+
+         Args:
+             config: Model configuration
+             generation_config: Generation configuration containing cache parameters
+             device: Device for the cache tensors
+             dtype: Data type of the cache
+             layer_device_map: Optional mapping of layer indices to devices
+             tp_size: Tensor parallelism size
+         """
+         self.config = config
+         self.dtype = dtype
+         self.device = device
+
+         # Extract model dimensions
+         kv_heads = getattr(config, "num_key_value_heads", None)
+         self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads
+         head_dim = getattr(config, "head_dim", None)
+         self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
+
+         # Extract cache dimensions
+         self.block_size = getattr(generation_config, "block_size", 32)
+
+         # Group layers depending on the attention mix
+         layer_groups, group_types = group_layers_by_attn_type(config)
+         group_size = len(layer_groups[0])
+         self.num_groups = len(layer_groups)
+
+         self.sliding_windows = {}
+         self.layer_index_to_group_indices = {}
+         for i, group in enumerate(layer_groups):
+             sliding_window = config.sliding_window if group_types[i] == "sliding_attention" else 1
+             for j, layer in enumerate(group):
+                 self.layer_index_to_group_indices[layer] = (i, j)
+                 self.sliding_windows[layer] = sliding_window
+
+         # Handle TP (or don't)
+         if tp_size is not None and tp_size > 1:
+             if self.num_key_value_heads % tp_size != 0:
+                 raise ValueError(
+                     f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
+                 )
+             # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
+             # self.num_key_value_heads //= tp_size  # TODO: why is this commented out?
+
+         # Infer number of blocks and max batch tokens
+         page_size = self.head_dim * self.num_key_value_heads
+
+         if getattr(config, "attn_implementation", None) == "paged_attention":
+             num_attention_masks = 0
+         else:
+             # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
+             num_attention_masks = 2 if "sliding_attention" in group_types else 1
+
+         memory_handler = PagedAttentionMemoryHandler(
+             block_size=self.block_size,
+             page_size=page_size,
+             num_groups=self.num_groups,
+             group_size=group_size,
+             peak_activation_per_token=(config.hidden_size + config.vocab_size),
+             num_attention_masks=num_attention_masks,
+         )
+         num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
+             num_blocks=getattr(generation_config, "num_blocks", None),
+             max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
+             max_memory_percent=getattr(generation_config, "max_memory", 0.9),
+             cache_dtype=self.dtype,
+         )
+
+         # Add the inferred attributes to the class
+         self.num_blocks = num_blocks
+         self.max_batch_tokens = max_batch_tokens
+         logger.info(
+             f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, "
+             f"{self.max_batch_tokens = } {num_attention_masks = }"
+         )
+
+         # Initialize the cache
+         self.key_cache: list[torch.Tensor] = []
+         self.value_cache: list[torch.Tensor] = []
+         # We add one extra token to the cache to handle padding and generally discard unwanted tokens
+         self.cache_shape = (num_blocks * self.block_size + 1, self.num_key_value_heads, self.head_dim)
+         for _ in range(group_size):
+             new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
+             new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
+             torch._dynamo.mark_static_address(new_layer_key_cache)
+             torch._dynamo.mark_static_address(new_layer_value_cache)
+             self.key_cache.append(new_layer_key_cache)
+             self.value_cache.append(new_layer_value_cache)
+         logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
+
+         # Block management data structures
+         self._free_blocks = deque(range(num_blocks))
+         self.group_cache_managers: list[CacheAllocator] = []
+         for i, group_type in enumerate(group_types):
+             if group_type == "full_attention":
+                 cm = FullAttentionCacheAllocator(i, self.block_size)
+             elif group_type == "sliding_attention":
+                 cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
+             else:
+                 raise ValueError(f"Invalid group type: {group_type}")
+             self.group_cache_managers.append(cm)
+
+     @traced
+     def allocate_blocks(self, n_blocks: int, request_id: str) -> Optional[int]:
+         """Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
+         managers; this method only returns the maximum number of blocks actually allocated across all managers."""
+         max_allocated = 0
+         for cm in self.group_cache_managers:
+             allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
+             if allocated is None:
+                 return None
+             max_allocated = max(max_allocated, allocated)
+         return max_allocated
+
+     @traced
+     def free_blocks(self, request_id: str) -> None:
+         """Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
+         by the cache managers."""
+         for cm in self.group_cache_managers:
+             cm.free_blocks(request_id, self._free_blocks)
+
+     def get_num_free_blocks(self) -> int:
+         """Get the current number of unallocated blocks available for new requests."""
+         return len(self._free_blocks)
+
+     @traced
+     def extend_read_indices(
+         self, request_id: str, past_length: int, query_length: int, read_index: list[list[int]]
+     ) -> None:
+         """Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method
+         coordinates with all cache managers to build the complete set of read indices needed for attention computation.
+         """
+         for cm, read_indices in zip(self.group_cache_managers, read_index):
+             indices = cm.get_read_indices(request_id, past_length, query_length)
+             read_indices.extend(indices)
+
+     @traced
+     def extend_write_indices(
+         self, request_id: str, past_length: int, query_length: int, write_index: list[list[int]]
+     ) -> None:
+         """Retrieve physical cache indices for writing new KV states to the cache across all layer groups. This method
+         coordinates with all cache managers to build the complete set of write indices needed to store the computed KV
+         states."""
+         for cm, write_indices in zip(self.group_cache_managers, write_index):
+             indices = cm.get_write_indices(request_id, past_length, query_length)
+             write_indices.extend(indices)
+
+     @traced
+     def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> dict[str, int]:
+         """Retrieve the key sequence length for the given request_id across all layer types. Returns a dictionary of
+         layer types to their corresponding key sequence lengths."""
+         seqlens_k = {}
+         for cm in self.group_cache_managers:
+             attn_type, seqlen_k = cm.get_seqlens_k(request_id, past_length, query_length)
+             seqlens_k[attn_type] = seqlen_k
+         return seqlens_k
+
+     @traced
+     def update(
+         self,
+         key_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
+         value_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
+         layer_idx: int,
+         read_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_kv + past_length]
+         write_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_q]
+         **kwargs,
+     ) -> tuple[torch.Tensor, torch.Tensor]:  # shape [seqlen_kv + past_length, num_kv_heads, head_dim]
+         """Update the cache with new key-value states for a specific layer. This method writes new KV states to the
+         appropriate cache locations. The behavior differs based on the layer's attention type:
+
+         - Full attention: new KV states are written to the cache, then the complete sequence is read from the cache
+         - Sliding window: old KV states are read from the cache along with extra spaces for the new KV, then the new
+           KV states are written to the cache. This is because the new KV might overwrite the old KV, so we need to
+           read the old KV first.
+
+         Returns the complete KV states (cached + new) for attention computation.
+         """
+         # Retrieve the layer read and write indices, and whether there is a sliding window
+         group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx]
+         layer_read_index = read_index[group_idx]
+         layer_write_index = write_index[group_idx]
+         # Select the correct cache
+         k_cache = self.key_cache[layer_idx_in_group]
+         v_cache = self.value_cache[layer_idx_in_group]
+         # Transpose the key and value states to match the cache shape, after which the shape is [seqlen_kv, num_kv_heads, head_dim]
+         key_states = key_states.transpose(1, 2).squeeze(0)
+         value_states = value_states.transpose(1, 2).squeeze(0)
+
+         # Case: full attention
+         sliding_window = self.sliding_windows[layer_idx]
+         if sliding_window == 1:
+             k_cache[layer_write_index, :, :] = key_states
+             v_cache[layer_write_index, :, :] = value_states
+             key_states_with_cache = k_cache[layer_read_index, :, :]
+             value_states_with_cache = v_cache[layer_read_index, :, :]
+
+         # Case: sliding window -- we need to be careful about read/write order because of chunked prefill: it is
+         # the only case where we may write over cache we still need to use
+         else:
+             # Add the cache to the key and value states
+             mask = layer_read_index == -1  # TODO: can this be efficiently precomputed?
+             key_states_with_cache = k_cache[layer_read_index, :, :]
+             key_states_with_cache[mask] = key_states
+             value_states_with_cache = v_cache[layer_read_index, :, :]
+             value_states_with_cache[mask] = value_states
+             # Write new KV values to the cache
+             k_cache[layer_write_index, :, :] = key_states
+             v_cache[layer_write_index, :, :] = value_states
+
+         # Return the new KV values
+         return key_states_with_cache, value_states_with_cache
+
+
+ # TODO: rework computation with the groups and their sizes
+ class PagedAttentionMemoryHandler:
+     """A helper class to determine the best number of pages and maximum number of tokens per batch for the paged
+     attention cache, providing automatic sizing based on available GPU memory.
+     The helper works using the number of pages, which is tied to the number of blocks by:
+         num_blocks = num_pages // block_size
+
+     The memory footprint consists of three main components:
+     - Cache memory: the space needed to store the cache tensors:
+         2 * layer_group_size * [num_pages, page_size] * cache_dtype
+     - Activation memory: the space temporarily taken by the largest activation during the model forward pass:
+         peak_activation_per_token * max_tokens_per_batch * activation_dtype_size
+     - Static tensors: the space taken by the input/output buffers and metadata tensors for batch processing, sum of:
+         - inputs_ids + outputs_ids + position_ids + logits_indices: 4 * max_tokens_per_batch * int32_size
+         - attention_mask: num_attention_masks * num_pages * max_tokens_per_batch * activation_dtype_size
+         - cumulative_seqlens_q + cumulative_seqlens_k: (1 + 2) * max_tokens_per_batch * int32_size
+         - write_index_tensor: num_groups * max_tokens_per_batch * int32_size
+         - read_index_tensor: num_groups * (num_pages + max_tokens_per_batch) * int32_size
+
+     The handler can operate in three modes:
+     1. Auto-sizing: determines both the number of pages and the maximum number of tokens per batch using quadratic
+        optimization
+     2. Fixed cache: calculates the max batch tokens given a fixed number of pages
+     3. Fixed batch: calculates the number of pages given a fixed maximum batch size
+     """
+
+     _activation_dtype = torch.bfloat16
+     _input_dtype = torch.int32
+     _upper_bound_max_batch_tokens = 256
+     _upper_bound_num_blocks = 4096
+
+     def __init__(
+         self,
+         block_size: int,
+         page_size: int,
+         num_groups: int,
+         group_size: int,
+         peak_activation_per_token: int,
+         num_attention_masks: int,
+     ) -> None:
+         """Initialize the memory handler with the parameters that cannot be automatically inferred.
+
+         Args:
+             block_size: Size of the cache blocks
+             page_size: Size of the cache pages
+             num_groups: Number of layer groups
+             group_size: Number of layers per layer group
+             peak_activation_per_token: Maximum size of the activation tensor per token, = hidden_size + vocab_size
+             num_attention_masks: Number of attention masks: 0 if no attention mask is used, 2 if hybrid model, else 1
+         """
+         self.block_size = block_size
+         self.page_size = page_size
+         self.num_groups = num_groups
+         self.group_size = group_size
+         self.peak_activation_per_token = peak_activation_per_token
+         self.num_attention_masks = num_attention_masks
+
+     @staticmethod
+     def get_available_memory(max_memory_percent: float = 1.0) -> int:
+         """Calculate available GPU memory for cache allocation, accounting for already allocated tensors.
+         This method queries the current memory state and applies the specified percentage limit to determine
+         how much memory can be safely used for the paged attention cache.
+
+         Args:
+             max_memory_percent: Fraction of available memory to use (0.0-1.0). 1.0 means use all available memory.
+
+         Returns:
+             int: Available memory in bytes for cache allocation
+         """
+         _, total, reserved, allocated = get_device_and_memory_breakdown()
+         available_memory = total - max(allocated, reserved)
+         available_memory = int(available_memory * max_memory_percent)
+         return available_memory
+
+     def infer_num_blocks_and_max_batch_tokens(
+         self,
+         num_blocks: Optional[int] = None,
+         max_batch_tokens: Optional[int] = None,
+         max_memory_percent: float = 0.9,
+         cache_dtype: torch.dtype = torch.float16,
+     ) -> tuple[int, int]:
+         """Determine the optimal number of blocks and maximum number of tokens per batch based on available memory
+         and constraints. Check the class docstring for more details. Naming the number of pages N and the maximum
+         number of tokens per batch M, the equation solved is:
+
+             available_memory = sum([
+                 MN * num_attention_masks * activation_dtype_size,
+                 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
+                 M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
+             ])
+
+         where we already simplified int32_size = 4.
+         """
+         # If neither num_blocks nor max_batch_tokens is provided, we solve a second-order polynomial
+         if num_blocks is None and max_batch_tokens is None:
+             num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(
+                 max_memory_percent, cache_dtype
+             )
+         # If only num_blocks is provided, we infer the max_batch_tokens
+         elif num_blocks is not None and max_batch_tokens is None:
+             max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype)
+         # If only max_batch_tokens is provided, we infer the num_blocks
+         elif max_batch_tokens is not None and num_blocks is None:
+             num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype)
+
+         # We check if the memory footprint is too large in all cases
+         available_memory = self.get_available_memory(max_memory_percent)
+         memory_footprint = self.compute_memory_footprint(
+             max_batch_tokens=max_batch_tokens,
+             num_blocks=num_blocks,
+             cache_dtype=cache_dtype,
+         )
+         if memory_footprint > available_memory:
+             raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}")
+         return num_blocks, max_batch_tokens
+
+     def compute_num_blocks_and_max_batch_tokens(
+         self,
+         max_memory_percent: float = 0.9,
+         cache_dtype: torch.dtype = torch.float16,
+         m: float = 0.01,
+     ) -> tuple[int, int]:
+         """Calculate the optimal number of blocks and maximum number of tokens per batch using quadratic optimization
+         when neither is fixed. This method assumes a relationship M = m * N, where m is a small ratio below 1, and
+         solves the resulting quadratic equation to find the optimal N that maximizes utilization within memory
+         constraints. m is the amount of cache we can fill with one batch: m = 0.01 means a batch fills at most 1% of
+         the cache. The equation to solve is:
+
+             available_memory = sum([
+                 m * N^2 * num_attention_masks * activation_dtype_size,
+                 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
+                 m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
+             ])
+         """
+         cache_memory = self.get_available_memory(max_memory_percent)
+         logger.info(f"Cache memory: {cache_memory}")
+
+         # Compute second-degree polynomial coefficients
+         a = m * self.num_attention_masks * self._activation_dtype.itemsize
+         b = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
+         b += m * (self.peak_activation_per_token * self._activation_dtype.itemsize + 28 + 4 * self.num_groups)
+         c = -cache_memory
+         logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
+
+         # Compute discriminant and greatest solution
+         discriminant = b**2 - 4 * a * c
+         if discriminant < 0:
+             raise ValueError(f"Discriminant is negative: {discriminant = }")
+         greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
+         if greatest_solution < 0:
+             raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
+
+         # Infer number of blocks and max batch tokens
+         num_pages = floor(greatest_solution)
+         num_blocks = num_pages // self.block_size
+         if num_blocks > self._upper_bound_num_blocks:
+             logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
+             num_blocks = self._upper_bound_num_blocks
+         max_batch_tokens = int(greatest_solution * m)
+         if max_batch_tokens > self._upper_bound_max_batch_tokens:
+             logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
+             max_batch_tokens = self._upper_bound_max_batch_tokens
+         return num_blocks, max_batch_tokens
+
+     def compute_max_batch_tokens(
+         self,
+         num_blocks: int,
+         max_memory_percent: float = 0.9,
+         cache_dtype: torch.dtype = torch.float16,
+     ) -> int:
+         """Calculate the maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
+
+             M = (available_memory - 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group))
+                 / (activation_dtype_size * (N * num_attention_masks + peak_activation_per_token) + 28 + 4 * num_group)
+         """
+         cache_memory = self.get_available_memory(max_memory_percent)
+         num_pages = num_blocks * self.block_size
+         # Compute numerator
+         num = cache_memory
+         num -= 2 * num_pages * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
+         # Compute denominator
+         denum = self._activation_dtype.itemsize * (
+             num_pages * self.num_attention_masks + self.peak_activation_per_token
+         )
+         denum += 28 + 4 * self.num_groups
+         # Compute max batch tokens and return
+         max_batch_tokens = floor(num / denum)
+         if max_batch_tokens > self._upper_bound_max_batch_tokens:
+             logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
+             max_batch_tokens = self._upper_bound_max_batch_tokens
+         return max_batch_tokens
+
+     def compute_num_blocks(
+         self,
+         max_batch_tokens: int,
+         max_memory_percent: float = 0.9,
+         cache_dtype: torch.dtype = torch.float16,
+     ) -> int:
+         """Calculate the number of cache blocks N given a fixed maximum number of tokens per batch M. The formula for
+         N is given by:
+
+             N = (available_memory - M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group))
+                 / (2 * (layer_group_size * page_size * cache_dtype + 2 * num_group) + M * num_attention_masks * activation_dtype_size)
+         """
+         cache_memory = self.get_available_memory(max_memory_percent)
+         # Compute numerator
+         num = cache_memory
+         num -= max_batch_tokens * self.peak_activation_per_token * self._activation_dtype.itemsize
+         num -= max_batch_tokens * (28 + 4 * self.num_groups)
+         # Compute denominator
+         denum = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
+         denum += max_batch_tokens * (self.num_attention_masks * self._activation_dtype.itemsize)
+         denum += max_batch_tokens * self._activation_dtype.itemsize
+         # Compute cache size and return number of blocks
+         num_pages = floor(num / denum)
+         num_blocks = num_pages // self.block_size
+         if num_blocks > self._upper_bound_num_blocks:
+             logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
+             num_blocks = self._upper_bound_num_blocks
+         return num_blocks
+
+     def compute_memory_footprint(
+         self,
+         num_blocks: Optional[int] = None,
+         max_batch_tokens: Optional[int] = None,
+         cache_dtype: torch.dtype = torch.float16,
+     ) -> int:
+         """Calculate the memory footprint breakdown for a given number of blocks and maximum batch tokens. The memory
+         footprint is given by:
+
+             available_memory = sum([
+                 MN * num_attention_masks * activation_dtype_size,
+                 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
+                 M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
+             ])
+
+         but is broken down below.
+         """
+         num_pages = num_blocks * self.block_size
+
+         cache_memory_footprint = 2 * self.group_size * num_pages * self.page_size * cache_dtype.itemsize
+
+         activation_memory_footprint = self.peak_activation_per_token * self._activation_dtype.itemsize
+         activation_memory_footprint *= max_batch_tokens
+
+         inputs_outputs_positions_and_logits_memory_footprint = 4 * max_batch_tokens * 4  # second 4 is the int32 size
+
+         attention_memory_footprint = self.num_attention_masks * self._activation_dtype.itemsize
+         attention_memory_footprint *= num_pages * max_batch_tokens
+
+         cumulative_seqlens_memory_footprint = 3 * max_batch_tokens * 4  # 4 is the int32 size
+
+         write_index_memory_footprint = self.num_groups * max_batch_tokens * 4  # 4 is the int32 size
+         read_index_memory_footprint = self.num_groups * (num_pages + max_batch_tokens) * 4  # 4 is the int32 size
+
+         total_memory_footprint = sum(
+             [
+                 cache_memory_footprint,
+                 activation_memory_footprint,
+                 inputs_outputs_positions_and_logits_memory_footprint,
+                 attention_memory_footprint,
+                 cumulative_seqlens_memory_footprint,
+                 write_index_memory_footprint,
+                 read_index_memory_footprint,
+             ]
+         )
+         return total_memory_footprint
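
As a rough illustration of the auto-sizing mode, the sketch below redoes the quadratic solve from `compute_num_blocks_and_max_batch_tokens` with invented dimensions; the constants 28 and 4 stand for the int32 bookkeeping tensors listed in the class docstring, and none of the numbers come from a real model:

```python
from math import floor, sqrt

# Assumed toy dimensions, invented for illustration only
available_memory = 8 * 1024**3             # pretend 8 GiB is free for the cache
num_attention_masks = 1                    # full-attention-only model
activation_itemsize = 2                    # bfloat16
cache_itemsize = 2                         # float16 cache
group_size = 32                            # layers per group
page_size = 8 * 128                        # num_kv_heads * head_dim
num_groups = 1
peak_activation_per_token = 4096 + 32000   # hidden_size + vocab_size
m = 0.01                                   # one batch fills at most 1% of the cache
block_size = 32

# available_memory = a*N^2 + b*N, i.e. a*N^2 + b*N + c = 0 with c = -available_memory
a = m * num_attention_masks * activation_itemsize
b = 2 * (group_size * page_size * cache_itemsize + 2 * num_groups)
b += m * (peak_activation_per_token * activation_itemsize + 28 + 4 * num_groups)
c = -available_memory

# Greatest root of the quadratic gives the number of pages N
num_pages = floor((-b + sqrt(b**2 - 4 * a * c)) / (2 * a))
num_blocks = num_pages // block_size
max_batch_tokens = int(num_pages * m)      # M = m * N
print(num_blocks, max_batch_tokens)        # both are capped by upper bounds in the real handler
```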
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py ADDED
@@ -0,0 +1,226 @@
+ # coding=utf-8
+ # Copyright 2025 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from abc import ABC, abstractmethod
+ from collections import deque
+ from math import ceil
+ from typing import Optional
+
+ from .requests import logger
+
+
+ class CacheAllocator(ABC):
+     """Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
+     when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
+
+     _index: int
+     _block_table: dict[str, list[int]]  # request_id -> list of block_ids allocated to the request
+
+     @abstractmethod
+     def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
+         """Allocates n_blocks for a given request_id. Returns the number of blocks allocated if successful and None
+         otherwise."""
+         pass
+
+     def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
+         """Frees all blocks associated with a request_id."""
+         if request_id in self._block_table:
+             blocks_to_free = self._block_table.pop(request_id)
+             free_blocks.extend(blocks_to_free)
+         else:
+             logger.warning(
+                 f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
+             )
+
+     @abstractmethod
+     def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
+         """Returns the physical indices of where to read request_id's cache in the cache tensor."""
+         pass
+
+     @abstractmethod
+     def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
+         """Returns the physical indices of where to write request_id's cache in the cache tensor."""
+         pass
+
+     @abstractmethod
+     def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
+         """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
+         pass
+
+
+ class FullAttentionCacheAllocator(CacheAllocator):
+     """Cache manager for a group of full attention layers."""
+
+     def __init__(self, index: int, block_size: int) -> None:
+         """Initializes the cache manager for a group of full attention layers.
+         Args:
+             - index: the index of the associated layer group
+             - block_size: the size of the blocks in the cache
+         """
+         self._index = index
+         self.block_size = block_size
+         self._block_table = {}
+
+     def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
+         """Allocates blocks for a given request_id. Returns the number of blocks allocated if successful and None
+         otherwise. For a group of full attention layers, we always allocate the number of requested blocks."""
+         if len(free_blocks) < n_blocks:
+             return None
+         if request_id not in self._block_table:
+             self._block_table[request_id] = []
+         self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
+         return n_blocks
+
+     def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
+         """Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
+         first write the new cache to the cache tensor and then read the entire cache from beginning to end."""
+         # Retrieve the block table for the request and raise an error if it doesn't exist
+         block_table = self._block_table.get(request_id)
+         if block_table is None:
+             raise ValueError(f"No block table found for request {request_id}")
+         # Compute the physical indices
+         physical_indices = []
+         for i in range(past_length + query_length):
+             block_idx = i // self.block_size
+             block_offset = i % self.block_size
+             physical_index = block_table[block_idx] * self.block_size + block_offset
+             physical_indices.append(physical_index)
+         return physical_indices
+
+     def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
+         """Returns the physical indices for writing to the cache. For a group of full attention layers, we write the
+         new cache as a continuation of the existing cache for the same request."""
+         block_table = self._block_table.get(request_id)
+         if block_table is None:
+             raise ValueError(f"No block table found for request {request_id}")
+         # Compute the physical indices
+         physical_indices = []
+         for i in range(past_length, past_length + query_length):
+             block_idx = i // self.block_size
+             block_offset = i % self.block_size
+             physical_index = block_table[block_idx] * self.block_size + block_offset
+             physical_indices.append(physical_index)
+         return physical_indices
+
+     def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
+         """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
+         seqlens_k = past_length + query_length
+         return "full_attention", seqlens_k
+
+
+ class SlidingAttentionCacheAllocator(CacheAllocator):
+     """Cache manager for a group of sliding window attention layers."""
+
+     def __init__(self, index: int, block_size: int, sliding_window: int) -> None:
+         """Initializes the cache manager for a group of sliding window attention layers.
+         Args:
+             - index: the index of the associated layer group
+             - block_size: the size of the blocks in the cache
+             - sliding_window: the size of the sliding window
+         """
+         self._index = index
+         self.block_size = block_size
+         self.sliding_window = sliding_window
+         self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
+         self._block_table = {}
+
+     def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
+         """Allocates blocks for a given request_id. Returns the number of blocks allocated if successful and None
+         otherwise. For a group of sliding window attention layers, we only allocate up to the point where we can fit
+         an entire sliding window in the cache tensor."""
+         if request_id not in self._block_table:
+             self._block_table[request_id] = []
+         # Early return if we are already at the max number of blocks per request
+         already_allocated = len(self._block_table[request_id])
+         if already_allocated == self._max_blocks_per_request:
+             return 0
+         # Compute the actual number of blocks to allocate
+         after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
+         actual_n_blocks = after_allocation - already_allocated
+         # Classic allocation
+         if len(free_blocks) < actual_n_blocks:
+             return None
+         self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
+         return actual_n_blocks
+
+     def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
+         """Returns the physical indices of where to read request_id's cache in the cache tensor.
+         For a group of sliding window attention layers, we read from the cache tensor before writing to it, because
+         the new cache can overwrite the old one. To form the cache + new key/value states, we read at most
+         sliding_window - 1 cache pages and then manually add the new key/value states after them. Hence the -1
+         indices, which indicate where to place the new key or value states."""
+         # Retrieve the block table for the request and raise an error if it doesn't exist
+         block_table = self._block_table.get(request_id)
+         if block_table is None:
+             raise ValueError(f"No block table found for request {request_id}")
+         # Apply sliding window
+         start_index = 0 if past_length < self.sliding_window else past_length % self.sliding_window
+         cache_length = min(past_length, self.sliding_window - 1)
+         # Compute the physical indices
+         physical_indices = []
+         for i in range(start_index, start_index + cache_length):
+             i %= self.sliding_window
+             block_idx = i // self.block_size
+             block_offset = i % self.block_size
+             physical_index = block_table[block_idx] * self.block_size + block_offset
+             physical_indices.append(physical_index)
+         return physical_indices + [-1] * query_length
+
+     def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
+         """Returns the physical indices of where to write request_id's cache in the cache tensor. For a group of
+         sliding window attention layers, we write the new cache in a rolling-buffer fashion: if we reach the end of
+         the allocated physical cache, we start writing from the beginning of the physical cache again."""
+         # Retrieve the block table for the request and raise an error if it doesn't exist
+         block_table = self._block_table.get(request_id)
+         if block_table is None:
+             raise ValueError(f"No block table found for request {request_id}")
+         # Apply sliding window
+         start_index = past_length % self.sliding_window
+         cache_length = min(query_length, self.sliding_window)
+         padding_length = query_length - cache_length
+         # Compute the physical indices
+         physical_indices = []
+         for i in range(start_index, start_index + cache_length):
+             i %= self.sliding_window
+             block_idx = i // self.block_size
+             block_offset = i % self.block_size
+             physical_index = block_table[block_idx] * self.block_size + block_offset
+             physical_indices.append(physical_index)
+         if padding_length > 0:
+             physical_indices = [-1] * padding_length + physical_indices
+         return physical_indices
+
+     def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
+         """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
+         seqlens_k = query_length + min(past_length, self.sliding_window - 1)
+         return "sliding_attention", seqlens_k
+
+
+ # TODO: test the impact of this
+ # def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
+ #     # Retrieve the block table for the request and raise an error if it doesn't exist
+ #     block_table = self._block_table.get(request_id)
+ #     if block_table is None:
+ #         raise ValueError(f"No block table found for request {request_id}")
+ #     # Compute the physical indices
+ #     physical_indices = []
+ #     n_left = past_length
+ #     for block_idx in block_table:
+ #         block_physical_index = block_idx * self.block_size
+ #         pages_used = min(self.block_size, n_left)
+ #         physical_indices.extend(block_physical_index + i for i in range(pages_used))
+ #         n_left -= pages_used
+ #         if n_left == 0:
+ #             return physical_indices
+ #     raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py ADDED
@@ -0,0 +1,1047 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import queue
17
+ import threading
18
+ from dataclasses import dataclass
19
+ from functools import partial
20
+ from itertools import count
21
+ from time import perf_counter
22
+ from typing import Optional, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from tqdm import tqdm
27
+
28
+ from ...configuration_utils import PretrainedConfig
29
+ from ...generation.configuration_utils import GenerationConfig
30
+ from ...utils.logging import logging
31
+ from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
32
+ from .cache import PagedAttentionCache
33
+ from .requests import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger
34
+ from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler
35
+
36
+
37
+ def build_attention_mask(
38
+ attention_mask: torch.Tensor,
39
+ cumulative_seqlens_q: torch.Tensor,
40
+ cumulative_seqlens_k: torch.Tensor,
41
+ sliding_window: int = 1,
42
+ ) -> None:
43
+ """Builds an attention mask inplace using the cumulative seqlens of the query and key. If given a sliding window, it
44
+ will also apply a sliding window mask on top. The attention mask is not boolean, it uses zeroes and -inf (or its
45
+ equivalent) so it's more of an attention score bias tensor.
46
+ The attention mask is a block-diagonal matrix, with each block an attention mask for a single query-key pair.
47
+ Each of those block is built from a causal mask and, if there is a sliding window, a sliding window mask.
48
+
49
+ An example is represented below, with seqlen_k = 8, seqlen_q = 4 and sliding_window = 6:
50
+
51
+ CAUSAL MASK:
52
+
53
+ █ █ █ █ █ ░ ░ ░
54
+ █ █ █ █ █ █ ░ ░
55
+ █ █ █ █ █ █ █ ░
56
+ █ █ █ █ █ █ █ █
57
+
58
+ SLIDING WINDOW MASK:
59
+ ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 8 - 4 - 6 = -2 offset to the right
60
+ <─┴─>
61
+ ░ █ | █ █ █ █ █ █ █ █
62
+ ░ ░ | █ █ █ █ █ █ █ █
63
+ ░ ░ | ░ █ █ █ █ █ █ █
64
+ ░ ░ | ░ ░ █ █ █ █ █ █
65
+
66
+ ATTENTION MASK (sum of causal and sliding window masks):
67
+
68
+ █ █ █ █ █ ░ ░ ░
69
+ █ █ █ █ █ █ ░ ░
70
+ ░ █ █ █ █ █ █ ░
71
+ ░ ░ █ █ █ █ █ █
72
+
73
+ Another example with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2:
74
+
75
+ CAUSAL MASK:
76
+
77
+ █ █ █ ░ ░
78
+ █ █ █ █ ░
79
+ █ █ █ █ █
80
+
81
+ SLIDING WINDOW MASK:
82
+ ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 5 - 3 - 2 = 0 offset to the right
83
+ <┴>
84
+ | ░ █ █ █ █
85
+ | ░ ░ █ █ █
86
+ | ░ ░ ░ █ █
87
+
88
+ ATTENTION MASK (sum of causal and sliding window masks):
89
+
90
+ ░ █ █ ░ ░
91
+ ░ ░ █ █ ░
92
+ ░ ░ ░ █ █
93
+
94
+ """
95
+ min_value = torch.finfo(attention_mask.dtype).min
96
+ for i in range(len(cumulative_seqlens_q) - 1):
97
+ seqlen_q = cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
98
+ seqlen_k = cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i]
99
+ if seqlen_q < seqlen_k and seqlen_q >= 1:
100
+ causal_diagonal = seqlen_k - seqlen_q + 1
101
+ else:
102
+ causal_diagonal = 1
103
+ query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1])
104
+ key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1])
105
+ # Apply causal mask
106
+ minus_inf = torch.full(
107
+ attention_mask[..., query_range, key_range].shape,
108
+ min_value,
109
+ dtype=attention_mask.dtype,
110
+ device=attention_mask.device,
111
+ )
112
+ masked = torch.triu(minus_inf, diagonal=causal_diagonal)
113
+ # Apply sliding window mask if needed
114
+ if sliding_window > 1:
115
+ sliding_diagonal = seqlen_k - seqlen_q - sliding_window
116
+ masked += torch.tril(minus_inf, diagonal=sliding_diagonal)
117
+ # Replace in attention mask
118
+ attention_mask[..., query_range, key_range] = masked
119
+
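+ # Illustrative usage sketch (added for clarity; not part of the original file).
+ # It reproduces the first docstring example above (seqlen_k=8, seqlen_q=4,
+ # sliding_window=6) for a single packed sequence; note that the call sites in
+ # this file pass plain Python lists for the cumulative seqlens:
+ #
+ #     mask = torch.zeros(1, 1, 4, 8, dtype=torch.float32)
+ #     build_attention_mask(
+ #         mask,
+ #         cumulative_seqlens_q=[0, 4],
+ #         cumulative_seqlens_k=[0, 8],
+ #         sliding_window=6,
+ #     )
+ #     # mask[0, 0] is now 0.0 where attention is allowed and the dtype minimum
+ #     # elsewhere, matching the "ATTENTION MASK" diagram in the docstring.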
120
+
121
+ @dataclass
122
+ class PagedAttentionArgs:
123
+ input_ids: torch.Tensor
124
+ attention_mask: Optional[torch.Tensor]
125
+ position_ids: torch.Tensor
126
+ cumulative_seqlens_q: torch.Tensor
127
+ cumulative_seqlens_k: torch.Tensor
128
+ max_seqlen_q: int
129
+ max_seqlen_k: int
130
+ write_index: list[torch.Tensor]
131
+ read_index: list[torch.Tensor]
132
+ logits_indices: torch.Tensor
133
+ cache: PagedAttentionCache
134
+ use_cache: bool = False
135
+
136
+
137
+ # Continuous Batch Processor (Internal Logic)
138
+ @attach_tracer()
139
+ class ContinuousBatchProcessor:
140
+ def __init__(
141
+ self,
142
+ cache: PagedAttentionCache,
143
+ config: PretrainedConfig,
144
+ generation_config: GenerationConfig,
145
+ input_queue: queue.Queue,
146
+ output_queue: queue.Queue,
147
+ stop_event: threading.Event,
148
+ model_device: torch.device,
149
+ model_dtype: torch.dtype,
150
+ scheduler: Scheduler,
151
+ streaming: bool = False,
152
+ manual_eviction: bool = False,
153
+ slice_inputs: bool = True, # TODO: There should be a heuristic to decide on slicing, compile, CUDA graphs...
154
+ ) -> None:
155
+ """Initialize the continuous batch processor.
156
+
157
+ Args:
158
+ cache: A [`PagedAttentionCache`] object
159
+ config: The model configuration
160
+ generation_config: The generation configuration
161
+ input_queue: Queue for incoming requests
162
+ output_queue: Queue for outgoing results
163
+ stop_event: Event to signal processing should stop
164
+ model_device: Device for model inputs/outputs
165
+ model_dtype: Data type for model inputs/outputs
166
+ scheduler: The [`Scheduler`] to use
167
+ streaming: Whether to stream tokens as they're generated
168
+ manual_eviction: Whether to manually evict blocks from the cache
169
+ slice_inputs: Whether to slice the inputs to the model
170
+ """
171
+ self.cache = cache
172
+ self.config = config
173
+ self.generation_config = generation_config
174
+ self.input_queue = input_queue
175
+ self.output_queue = output_queue
176
+ self.stop_event = stop_event
177
+ self.model_device = model_device
178
+ self.model_dtype = model_dtype
179
+ self.scheduler = scheduler
180
+ self.streaming = streaming
181
+ self.manual_eviction = manual_eviction
182
+ self.slice_inputs = slice_inputs
183
+
184
+ # Retrieve the size of the sliding window if there is one
185
+ self.sliding_window = 1 if getattr(config, "sliding_window", None) is None else config.sliding_window
186
+
187
+ self.requests_in_batch: list[RequestState] = []
188
+
189
+ # Set up metrics collector
190
+ self.max_batch_tokens = cache.max_batch_tokens
191
+ self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
192
+
193
+ # Setup static tensors
194
+ self.total_query_length = 0
195
+ self.total_key_length = 0
196
+ self.total_batch_size = 0
197
+ self.setup_static_tensors(cache.num_groups)
198
+
199
+ @traced(standalone=True)
200
+ def setup_static_tensors(self, num_groups: int) -> None:
201
+ T = self.max_batch_tokens
202
+ num_pages = self.cache.num_blocks * self.cache.block_size
203
+ self.tensor_metadata = {"dtype": torch.int32, "device": self.model_device}
204
+
205
+ # Some tensors always have the same shape regardless of the model
206
+ self.input_ids = torch.empty((1, T), **self.tensor_metadata)
207
+ self.position_ids = torch.empty((1, T), **self.tensor_metadata)
208
+ self.cumulative_seqlens_q = torch.empty((T + 1,), **self.tensor_metadata)
209
+ self.max_seqlen_q = 0
210
+ self.logits_indices = torch.empty((T,), **self.tensor_metadata)
211
+ self.output_ids = torch.empty((1, T), **self.tensor_metadata)
212
+
213
+ # For some kwargs, we have a dict of tensors with as many items as there are attention types
214
+ layer_types = getattr(self.config, "layer_types", None)
215
+ if layer_types is None:
216
+ sliding_window = getattr(self.config, "sliding_window", 1)
217
+ layer_types = ["full_attention"] if sliding_window in [1, None] else ["sliding_attention"]
218
+ layer_types = list(set(layer_types))
219
+
220
+ self.cumulative_seqlens_k = {
221
+ layer_type: torch.empty((T + 1), **self.tensor_metadata) for layer_type in layer_types
222
+ }
223
+ self.max_seqlen_k = dict.fromkeys(layer_types, 0)
224
+
225
+ if self.return_attention_mask():
226
+ attn_mask_kwargs = {
227
+ "size": (1, 1, T, num_pages + T),
228
+ "dtype": self.model_dtype,
229
+ "device": self.model_device,
230
+ }
231
+ self.attention_mask = {layer_type: torch.empty(**attn_mask_kwargs) for layer_type in layer_types}
232
+ else:
233
+ self.attention_mask = None
234
+
235
+ # For other kwargs, we need a list of tensors with as many tensors as there are groups
236
+ self.write_index_storage = [torch.empty((T,), **self.tensor_metadata) for _ in range(num_groups)]
237
+ self.read_index_storage = [torch.empty((num_pages + T), **self.tensor_metadata) for _ in range(num_groups)]
238
+ # For the read index, the +T is because there are -1 entries for seqlen_q when the model uses a sliding window
239
+
240
+ # After allocating empty tensors, we reset them to the right value
241
+ self.reset_static_tensors(full_reset=True)
242
+
243
+ def return_attention_mask(self) -> bool:
244
+ return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call
245
+
246
+ @traced
247
+ @torch.no_grad()
248
+ def reset_static_tensors(self, full_reset: bool = False):
249
+ """Reset static tensors for the next batch. In between batches, reset only the parts that were used in the last
250
+ batch, but for initialization, everything can be reset using the `full_reset` flag."""
251
+ # Compute the slice to reset
252
+ if full_reset or not self.slice_inputs:
253
+ q_len = self.write_index_storage[0].size(-1)
254
+ k_len = self.read_index_storage[0].size(-1)
255
+ b_size = self.write_index_storage[0].size(0)
256
+ else:
257
+ q_len = self.total_query_length
258
+ k_len = self.total_key_length
259
+ b_size = self.total_batch_size
260
+
261
+ # Reset the attributes that always have the same shape
262
+ self.input_ids[:, :q_len].zero_()
263
+ self.position_ids[:, :q_len].zero_()
264
+ self.cumulative_seqlens_q[: b_size + 1].zero_()
265
+ self.max_seqlen_q = 0
266
+ self.logits_indices[:q_len].fill_(-1)
267
+ self.output_ids[:, :q_len].fill_(-1)
268
+
269
+ # Reset the attributes that are either tensors or dict of tensors
270
+ for layer_type in self.cumulative_seqlens_k:
271
+ self.cumulative_seqlens_k[layer_type][: b_size + 1].zero_()
272
+ self.max_seqlen_k[layer_type] = 0
273
+ if self.attention_mask is not None:
274
+ self.attention_mask[layer_type][:, :, :q_len, :k_len].fill_(torch.finfo(self.model_dtype).min)
275
+
276
+ # Reset the attributes that are lists of tensors
277
+ for i in range(self.cache.num_groups):
278
+ self.write_index_storage[i][:q_len].fill_(-1)
279
+ self.read_index_storage[i][: q_len + k_len].fill_(-1)
280
+
281
+ def get_model_kwargs(self) -> PagedAttentionArgs:
282
+ """Get model keyword arguments for the current batch."""
283
+ # Compute the slice to return
284
+ q_len = self.total_query_length if self.slice_inputs else self.write_index_storage[0].size(-1)
285
+ b_size = self.total_batch_size if self.slice_inputs else self.cumulative_seqlens_q.size(-1) - 1
286
+
287
+ # Prepare the kwargs, the attributes that are either tensors or dict of tensors are initialized to empty dicts
288
+ kwargs = {
289
+ "input_ids": self.input_ids[:, :q_len],
290
+ "position_ids": self.position_ids[:, :q_len],
291
+ "cu_seq_lens_q": self.cumulative_seqlens_q[: b_size + 1],
292
+ "max_seqlen_q": self.max_seqlen_q,
293
+ "logits_indices": self.logits_indices[:q_len],
294
+ "cu_seq_lens_k": {},
295
+ "max_seqlen_k": {},
296
+ "attention_mask": {},
297
+ "read_index": self.read_index, # slicing is done during building
298
+ "write_index": self.write_index, # slicing is done during building
299
+ "cache": self.cache,
300
+ "use_cache": False,
301
+ }
302
+
303
+ # For the attributes that are dict of tensors, we replace the dict with a tensor if there is only one entry
304
+ layer_types = list(self.cumulative_seqlens_k.keys())
305
+ if len(layer_types) > 1:
306
+ for layer_type, seqlens_k in self.cumulative_seqlens_k.items():
307
+ kwargs["cu_seq_lens_k"][layer_type] = seqlens_k[: b_size + 1]
308
+ kwargs["max_seqlen_k"][layer_type] = self.max_seqlen_k[layer_type]
309
+ if self.attention_mask is not None:
310
+ k_len = seqlens_k[b_size] if self.slice_inputs else self.attention_mask[layer_type].size(-1)
311
+ kwargs["attention_mask"][layer_type] = self.attention_mask[layer_type][..., :q_len, :k_len]
312
+ else:
313
+ layer_type = layer_types[0]
314
+ kwargs["cu_seq_lens_k"] = self.cumulative_seqlens_k[layer_type][: b_size + 1]
315
+ kwargs["max_seqlen_k"] = self.max_seqlen_k[layer_type]
316
+ if self.attention_mask is not None:
317
+ k_len = self.cumulative_seqlens_k[layer_type][b_size]
318
+ k_len = k_len if self.slice_inputs else self.attention_mask[layer_type].size(-1)
319
+ kwargs["attention_mask"] = self.attention_mask[layer_type][..., :q_len, :k_len]
320
+
321
+ if self.attention_mask is None:
322
+ kwargs["attention_mask"] = None
323
+ return kwargs
324
+
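+ # Illustrative note (added for clarity; not part of the original file): for a
+ # model with a single attention type, e.g. layer_types == ["full_attention"],
+ # the returned kwargs hold plain tensors:
+ #
+ #     kwargs["cu_seq_lens_q"]  # tensor of shape (batch_size + 1,)
+ #     kwargs["cu_seq_lens_k"]  # tensor of shape (batch_size + 1,)
+ #     kwargs["max_seqlen_k"]   # int
+ #
+ # whereas a model mixing full and sliding attention gets dicts keyed by layer
+ # type, e.g. kwargs["cu_seq_lens_k"]["sliding_attention"].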
325
+ def __repr__(self):
326
+ return (
327
+ f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, "
328
+ f"active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})"
329
+ + self.get_model_kwargs().__repr__()
330
+ )
331
+
332
+ @traced
333
+ def _get_new_requests(self):
334
+ """Pull new requests from the input queue and add to waiting list."""
335
+ while not self.input_queue.empty():
336
+ try:
337
+ state = self.input_queue.get_nowait()
338
+ if state is None: # Sentinel value
339
+ continue
340
+ self.scheduler.add_waiting_request(state)
341
+
342
+ except queue.Empty:
343
+ break
344
+ except Exception as e:
345
+ logger.error(f"Error processing new request: {e}", exc_info=True)
346
+ state: RequestState = locals().get("state")
347
+ if state is not None:
348
+ self._handle_request_error(e, state)
349
+
350
+ @traced
351
+ def _handle_request_error(self, error, state: RequestState):
352
+ """Handle general request processing error."""
353
+ state.status = RequestStatus.FAILED
354
+ state.error = str(error)
355
+
356
+ # Include any generated tokens if this is an active request
357
+ if isinstance(state.request_id, str):
358
+ state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id)
359
+ else:
360
+ state.static_outputs = []
361
+
362
+ self.metrics.record_request_completion(state.created_time, state.request_id)
363
+ self.output_queue.put(state.to_generation_output())
364
+
365
+ @traced
366
+ def prepare_next_batch(self) -> bool:
367
+ """Prepare tensors and metadata for the next model forward pass. Returns True if there are requests to process,
368
+ False otherwise."""
369
+
370
+ # Get new requests from the queue, stop if there are no pending requests
371
+ self._get_new_requests()
372
+ self.scheduler.clear_cancelled_requests()
373
+ if not self.scheduler.has_pending_requests():
374
+ return False
375
+ self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests))
376
+
377
+ # Schedule the next batch of requests, stop if there are no requests in the batch
378
+ self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens)
379
+ if not self.requests_in_batch:
380
+ return False
381
+ self.metrics.record_batch_metrics(self.requests_in_batch)
382
+
383
+ # Reset the static tensors used for storage
384
+ self.reset_static_tensors() # TODO: with slice_inputs, this might be unnecessary
385
+
386
+ # Prepare accumulators
387
+ self.total_query_length = 0
388
+ self.total_key_length = 0
389
+ self.total_batch_size = 0
390
+
391
+ input_ids = []
392
+ position_ids = []
393
+ cumulative_seqlens_q = [0]
394
+ logits_indices = []
395
+
396
+ if isinstance(self.cumulative_seqlens_k, dict):
397
+ cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
398
+ else:
399
+ cumulative_seqlens_k = [0]
400
+
401
+ read_index = [[] for _ in range(self.cache.num_groups)]
402
+ write_index = [[] for _ in range(self.cache.num_groups)]
403
+
404
+ # Go through all the requests in the batch
405
+ for state in self.requests_in_batch:
406
+ # First we retrieve the lengths related to the request
407
+ past_length = state.position_offset
408
+ query_length = len(state.prompt_ids)
409
+ seqlens_k = self.cache.get_seqlens_k(state.request_id, past_length, query_length)
410
+
411
+ # Then we update the total lengths that are used for slicing
412
+ self.total_query_length += query_length
413
+ # total_key_length is used to slice the keys so we need to take the max of all the key lengths
414
+ self.total_key_length += max(seqlens_k.values())
415
+ self.total_batch_size += 1
416
+ # And the attribute tracking the position in the request object
417
+ state.position_offset += query_length
418
+
419
+ # Then we accumulate for the object used in the kwargs
420
+ input_ids.extend(state.prompt_ids)
421
+ position_ids.extend(range(past_length, past_length + query_length))
422
+ cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length)
423
+ self.max_seqlen_q = max(self.max_seqlen_q, query_length)
424
+
425
+ if not state.remaining_prompt_ids:
426
+ logits_indices.append(cumulative_seqlens_q[-1] - 1)
427
+
428
+ for layer_type, layer_type_seqlen_k in seqlens_k.items():
429
+ cumulative_seqlens_k[layer_type].append(cumulative_seqlens_k[layer_type][-1] + layer_type_seqlen_k)
430
+ self.max_seqlen_k[layer_type] = max(self.max_seqlen_k[layer_type], layer_type_seqlen_k)
431
+
432
+ self.cache.extend_read_indices(state.request_id, past_length, query_length, read_index)
433
+ self.cache.extend_write_indices(state.request_id, past_length, query_length, write_index)
434
+
435
+ # Once the loop over requests is done, we can build the actual tensors
436
+ self._build_tensors(
437
+ input_ids,
438
+ position_ids,
439
+ read_index,
440
+ write_index,
441
+ cumulative_seqlens_q,
442
+ cumulative_seqlens_k,
443
+ logits_indices,
444
+ )
445
+ self.metrics.record_kv_cache_memory_metrics(self.cache)
446
+
447
+ if logger.isEnabledFor(logging.DEBUG):
448
+ if isinstance(self.cumulative_seqlens_k, dict):
449
+ ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
450
+ else:
451
+ ck = cumulative_seqlens_k[-1]
452
+ logger.debug(
453
+ f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
454
+ f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
455
+ f"cum KV: {ck}, free blocks: {self.cache.get_num_free_blocks()}"
456
+ )
457
+ return True
458
+
459
+ @traced
460
+ def _build_tensors(
461
+ self,
462
+ input_ids: list[int],
463
+ position_ids: list[int],
464
+ read_index: list[list[int]],
465
+ write_index: list[list[int]],
466
+ cumulative_seqlens_q: list[int],
467
+ cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
468
+ logits_indices: list[int],
469
+ ) -> None:
470
+ """Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
471
+ to_tensor = partial(torch.tensor, **self.tensor_metadata)
472
+
473
+ # Those kwargs always have the same type regardless of the model
474
+ self.input_ids[:, : len(input_ids)] = to_tensor(input_ids)
475
+ self.position_ids[:, : len(position_ids)] = to_tensor(position_ids)
476
+ self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q)
477
+ self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)
478
+
479
+ # Those kwargs are either dict of tensors or tensors, so we need to handle both cases
480
+ for layer_type, layer_type_seqlens_k in cumulative_seqlens_k.items():
481
+ self.cumulative_seqlens_k[layer_type][: len(layer_type_seqlens_k)] = to_tensor(layer_type_seqlens_k)
482
+ if self.attention_mask is not None:
483
+ build_attention_mask(
484
+ attention_mask=self.attention_mask[layer_type],
485
+ cumulative_seqlens_q=cumulative_seqlens_q,
486
+ cumulative_seqlens_k=layer_type_seqlens_k,
487
+ sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1,
488
+ )
489
+
490
+ # The indices only contain references to the storage tensors, so we update the storage and their references
491
+ self.read_index = []
492
+ self.write_index = []
493
+ for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index):
494
+ # Write in the actual tensors
495
+ self.read_index_storage[i][: len(group_read_indices)] = to_tensor(group_read_indices)
496
+ self.write_index_storage[i][: len(group_write_indices)] = to_tensor(group_write_indices)
497
+ # Slice to the right size
498
+ r = len(group_read_indices) if self.slice_inputs else self.read_index_storage[i].size(-1)
499
+ w = len(group_write_indices) if self.slice_inputs else self.write_index_storage[i].size(-1)
500
+ # Add to the index
501
+ self.read_index.append(self.read_index_storage[i][:r])
502
+ self.write_index.append(self.write_index_storage[i][:w])
503
+
504
+ @traced
505
+ def _sync(self):
506
+ if self.output_ids is not None:
507
+ try:
508
+ out = self.output_ids.tolist()[0] # should be the only sync we do
509
+ except Exception:
510
+ out = [0, 1]
511
+ else:
512
+ out = [0, 0]
513
+ return out
514
+
515
+ @traced
516
+ def _maybe_send_output(self, state: RequestState, token: int):
517
+ """Send output to the queue based on streaming mode and request state."""
518
+ if self.streaming:
519
+ self.output_queue.put(state.to_generation_output())
520
+ elif state.status == RequestStatus.FINISHED:
521
+ self.output_queue.put(state.to_generation_output())
522
+
523
+ @traced
524
+ def update_batch(self):
525
+ """Update request states based on generated tokens."""
526
+ out_tokens = self._sync()
527
+ finished_request_ids = []
528
+ for i, state in enumerate(self.requests_in_batch):
529
+ req_id = state.request_id
530
+ if len(state.remaining_prompt_ids) == 0:
531
+ self.metrics.record_ttft_metric(state.created_time, state.request_id)
532
+ state.status = RequestStatus.DECODING
533
+ token = out_tokens[self.logits_indices[i]]
534
+ state.prompt_ids = [token]
535
+ if state.update_with_token(token):
536
+ self.metrics.record_request_completion(state.created_time, state.request_id)
537
+ self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
538
+ finished_request_ids.append(req_id)
539
+ self._maybe_send_output(state, token)
540
+ elif state.status == RequestStatus.PREFILLING_SPLIT:
541
+ state.status = RequestStatus.SPLIT_PENDING_REMAINDER
542
+ if self.cache.get_num_free_blocks() == 0:
543
+ raise ValueError("No more free blocks")
544
+
545
+ @traced
546
+ def has_pending_requests(self) -> bool:
547
+ """Check if there are any active or waiting requests."""
548
+ return self.scheduler.has_pending_requests()
549
+
550
+ @traced
551
+ def handle_batch_error(self, error):
552
+ """Handle errors during batch processing."""
553
+ failed_reqs = self.requests_in_batch
554
+ for req in failed_reqs:
555
+ self._handle_request_error(error, req)
556
+ self.scheduler.finish_request(req.request_id)
557
+
558
+ @traced
559
+ def fail_all_requests(self, error):
560
+ """Fail all active requests with the given error.
561
+
562
+ Args:
563
+ error: The error to report in the failure message
564
+ """
565
+
566
+ requests = list(self.scheduler.active_requests.values())
567
+ for state in requests:
568
+ self._handle_request_error(error, state)
569
+ self.scheduler.finish_request(state.request_id)
570
+
571
+ # Also fail any requests in the waiting queue
572
+ for req_id in list(self.scheduler.waiting_requests.keys()):
573
+ state = self.scheduler.waiting_requests.pop(req_id)
574
+ self._handle_request_error(error, state)
575
+
576
+ # Clear the ordering queue
577
+ self.scheduler.waiting_requests_order.clear()
578
+
579
+
580
+ # Manager Class (User Interface)
581
+ @attach_tracer()
582
+ class ContinuousBatchingManager:
583
+ """Manager for handling continuous batching of generation requests.
584
+
585
+ This class provides the user interface for submitting generation requests,
586
+ retrieving results, and managing the background generation thread.
587
+ """
588
+
589
+ def __init__(
590
+ self,
591
+ model,
592
+ generation_config: GenerationConfig,
593
+ manual_eviction: bool = False,
594
+ max_queue_size=0,
595
+ streaming: bool = True,
596
+ slice_inputs: bool = True,
597
+ ):
598
+ """Initialize the continuous batching manager.
599
+
600
+ Args:
601
+ model: The language model for generation
602
+ generation_config: Configuration for generation parameters
603
+ max_queue_size: Maximum size of the request queue (0 = unlimited)
604
+ streaming: Whether to stream tokens as they are generated
+ manual_eviction: Whether cache blocks for finished requests must be evicted manually
+ slice_inputs: Whether to slice the static input tensors to the size of the current batch
605
+ """
606
+ self.model = model.eval()
607
+ generation_config = model.generation_config if generation_config is None else generation_config
608
+ self.generation_config = generation_config
609
+ self.input_queue = queue.Queue(maxsize=max_queue_size)
610
+ self.output_queue = queue.Queue()
611
+ self.stop_event = threading.Event()
612
+ self.streaming = streaming
613
+ self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
614
+ self._generation_thread = None
615
+ self._request_counter = 0
616
+ self._request_lock = threading.Lock()
617
+ self.model.generation_config.top_p = None
618
+ self.do_sample = getattr(generation_config, "do_sample", True)
619
+ self.logit_processor = self.model._get_logits_processor(generation_config)
620
+ self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", False) # TODO: same as do_sample
621
+ self.profile = getattr(generation_config, "profile", False)
622
+ self.manual_eviction = manual_eviction
623
+ self.batch_processor: Optional[ContinuousBatchProcessor] = None
624
+ self.slice_inputs = slice_inputs
625
+
626
+ if self.use_cuda_graph:
627
+ raise NotImplementedError("Cuda graphs are not supported yet")
628
+
629
+ @traced
630
+ def start(self):
631
+ """Start the background generation thread."""
632
+ if self._generation_thread is not None and self._generation_thread.is_alive():
633
+ logger.warning("Manager thread is already running.")
634
+ return
635
+
636
+ self._result_queue = queue.Queue()
637
+ self._generation_thread = threading.Thread(target=self._run_generation_loop)
638
+ self._generation_thread.start()
639
+
640
+ def is_running(self):
641
+ """Check if the background generation thread is running."""
642
+ return self._generation_thread is not None and self._generation_thread.is_alive()
643
+
644
+ def stop(self, block: bool = False, timeout: Optional[float] = None):
645
+ """Signal the background thread to stop.
646
+
647
+ Args:
648
+ block: Whether to wait for the thread to stop
649
+ timeout: Maximum time to wait for the thread to stop
650
+ """
651
+ if self._generation_thread is None:
652
+ logger.warning("Manager not started.")
653
+ return
654
+
655
+ if not self.stop_event.is_set():
656
+ self.stop_event.set()
657
+ logger.info("Stopping continuous batching manager...")
658
+
659
+ if block:
660
+ self.join(timeout)
661
+
662
+ def join(self, timeout: Optional[float] = None):
663
+ """Wait for the background thread to finish.
664
+
665
+ Args:
666
+ timeout: Maximum time to wait for the thread to stop
667
+ """
668
+ if self._generation_thread is not None:
669
+ self._generation_thread.join(timeout=timeout)
670
+ if self._generation_thread.is_alive():
671
+ logger.warning("Generation thread did not exit after join timeout.")
672
+ else:
673
+ logger.info("Continuous Batching Manager stopped.")
674
+ self._generation_thread = None
675
+
676
+ def add_request(
677
+ self, input_ids: list[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None
678
+ ) -> str:
679
+ """Add a new generation request to the queue.
680
+
681
+ Args:
682
+ input_ids: Input token IDs to use as prompt
683
+ request_id: Optional custom request ID (auto-generated if None)
684
+ max_new_tokens: Optional maximum number of new tokens to generate (defaults to the generation config value)
685
+
686
+ Returns:
687
+ str: The request ID
688
+ """
689
+ if request_id is None:
690
+ with self._request_lock:
691
+ request_id = f"req_{self._request_counter}"
692
+ self._request_counter += 1
693
+
694
+ max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens
695
+
696
+ # NOTE: do we want to handle a case when the user wants token ids returned instead of decoded text?
697
+ state = RequestState(
698
+ request_id=request_id,
699
+ prompt_ids=list(input_ids),
700
+ full_prompt_ids=list(input_ids),
701
+ max_new_tokens=max_new_tokens,
702
+ eos_token_id=self.generation_config.eos_token_id,
703
+ )
704
+
705
+ # Use block=True with timeout to handle backpressure if queue is full
706
+ self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg?
707
+ logger.debug(f"Added request {request_id} to queue.")
708
+ return request_id
709
+
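+ # Illustrative usage sketch (added for clarity; not part of the original file,
+ # and `model`/`tokenizer` are assumed to be a loaded causal LM and its tokenizer):
+ #
+ #     manager = model.init_continuous_batching(streaming=True)
+ #     manager.start()
+ #     req_id = manager.add_request(tokenizer("Hello")["input_ids"], max_new_tokens=32)
+ #     for output in manager.request_id_iter(req_id):
+ #         if output.status == RequestStatus.FINISHED:
+ #             print(tokenizer.decode(output.generated_tokens))
+ #             break
+ #     manager.stop(block=True)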
710
+ def add_requests(self, inputs: list[list[int]], **kwargs):
711
+ for input_ids in inputs:
712
+ self.add_request(input_ids, **kwargs)
713
+
714
+ def cancel_request(self, request_id: str):
715
+ """Cancel a request by its ID.
716
+
717
+ Args:
718
+ request_id: The ID of the request to cancel
719
+ """
720
+ if self.batch_processor is not None:
721
+ self.batch_processor.scheduler.set_request_cancellation(request_id)
722
+
723
+ def get_result(self, request_id=None, timeout=None) -> Optional[GenerationOutput]:
724
+ """Retrieve one result from the output queue.
725
+
726
+ Args:
727
+ request_id: If given, only a result for this request ID is returned; other results are put back in the queue
+ timeout: Maximum time to wait for a result
728
+
729
+ Returns:
730
+ Optional[GenerationOutput]: The result data or None if timeout
731
+ """
732
+ if self._generation_thread is None and self.output_queue.empty():
733
+ return None
734
+ try:
735
+ result = self.output_queue.get(block=True, timeout=timeout)
736
+ if request_id is not None and result.request_id != request_id:
737
+ self.output_queue.put(result)
738
+ return None
739
+ logger.debug(f"Retrieved result for request {result.request_id}")
740
+ return result
741
+ except queue.Empty:
742
+ return None
743
+
744
+ def __iter__(self):
745
+ """Iterate over results as they become available."""
746
+ while self._generation_thread is not None and self._generation_thread.is_alive():
747
+ result = self.get_result(timeout=0.1)
748
+ if result is not None:
749
+ yield result
750
+
751
+ def request_id_iter(self, request_id):
752
+ """Iterate over results matching a specific request id as they become available."""
753
+ request_cancelled = False
754
+ while self._generation_thread is not None and self._generation_thread.is_alive() and not request_cancelled:
755
+ result = self.get_result(request_id=request_id, timeout=0.1)
756
+ if result is not None:
757
+ yield result
758
+ if self.batch_processor is not None:
759
+ request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
760
+
761
+ @staticmethod
762
+ def supported_attention_implementations() -> set[str]:
763
+ return {"eager_paged", "sdpa_paged", "flash_attention_2"}
764
+
765
+ @staticmethod
766
+ def default_attention_implementation() -> str:
767
+ return "sdpa_paged"
768
+
769
+ @traced
770
+ def warmup(self, batch_processor):
771
+ stream = torch.cuda.Stream(device=self.model.device)
772
+ stream.wait_stream(torch.cuda.current_stream())
773
+ with torch.cuda.stream(stream):
774
+ # Warmup the model with a dummy forward pass
775
+ self._generation_step(batch_processor)
776
+ torch.cuda.current_stream().wait_stream(stream)
777
+
778
+ self.graph = torch.cuda.CUDAGraph()
779
+ with torch.cuda.graph(self.graph, stream=stream):
780
+ self._generation_step(batch_processor)
781
+
782
+ @traced
783
+ # @torch.compile
784
+ def _generation_step(self, batch_processor: ContinuousBatchProcessor):
785
+ """Perform a single generation step. This is cuda graphed"""
786
+ batch_data = batch_processor.get_model_kwargs()
787
+ with torch.no_grad():
788
+ logits = self._model_forward(batch_data)
789
+ if self.log_prob_generation:
790
+ batch_processor.output_probs.copy_(logits) # TODO
791
+ probs = self._process_logit(batch_data, logits)
792
+ self._sample(batch_processor, probs)
793
+
794
+ @traced(span_name="model_forward")
795
+ def _model_forward(self, batch_data):
796
+ return self.model(**batch_data).logits
797
+
798
+ @traced(span_name="logit_processing")
799
+ def _process_logit(self, batch_data, logits):
800
+ # Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner!
801
+ if hasattr(self.logit_processor, "set_continuous_batching_context"):
802
+ self.logit_processor.set_continuous_batching_context(
803
+ batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
804
+ )
805
+
806
+ # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
807
+ # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
808
+ batch_size, seq_len, vocab_size = logits.shape
809
+ logits_2d = logits.view(batch_size * seq_len, vocab_size)
810
+ input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
811
+
812
+ # Process with 2D tensors
813
+ processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d)
814
+
815
+ # Reshape back to 3D
816
+ return processed_logits_2d.view(batch_size, seq_len, vocab_size)
817
+
818
+ @traced(span_name="sampling")
819
+ def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
820
+ if self.do_sample: # sample
821
+ probs = nn.functional.softmax(probs, dim=-1)
822
+ # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
823
+ next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1) # Now [seq_len]
824
+ # Add batch dimension back to match argmax output
825
+ next_tokens = next_tokens.unsqueeze(0) # Now [1, seq_len]
826
+ else:
827
+ next_tokens = torch.argmax(probs, dim=-1) # Already [1, seq_len]
828
+
829
+ tokens = next_tokens.size(1) # Get seq_len dimension
830
+ batch_processor.output_ids[:, :tokens].copy_(next_tokens)
831
+
832
+ def _run_generation_loop(self):
833
+ """Main processing loop running in the background thread."""
834
+ batch_processor = None
835
+ try:
836
+ ref_time = perf_counter()
837
+ paged_attention_cache = PagedAttentionCache(
838
+ self.model.config,
839
+ self.generation_config,
840
+ self.model.device,
841
+ self.model.dtype,
842
+ tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
843
+ )
844
+ logger.debug(f"PagedAttentionCache created in {perf_counter() - ref_time} seconds")
845
+
846
+ scheduler = None
847
+ if hasattr(self.generation_config, "scheduler"):
848
+ scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler, None)
849
+ if scheduler is None:
850
+ logger.warning(f"Scheduler '{scheduler}' not found. Defaulting to FIFO.")
851
+ scheduler = FIFOScheduler
852
+ else:
853
+ # Default to fifo
854
+ scheduler = FIFOScheduler
855
+
856
+ ref_time = perf_counter()
857
+ batch_processor = ContinuousBatchProcessor(
858
+ paged_attention_cache,
859
+ self.model.config,
860
+ self.generation_config,
861
+ self.input_queue,
862
+ self.output_queue,
863
+ self.stop_event,
864
+ self.model.device,
865
+ self.model.dtype,
866
+ scheduler(paged_attention_cache, self.manual_eviction),
867
+ self.streaming,
868
+ self.manual_eviction,
869
+ slice_inputs=self.slice_inputs,
870
+ )
871
+ self.batch_processor = batch_processor
872
+ self.current_batch = 0
873
+ logger.debug(f"batch_processor created in {perf_counter() - ref_time} seconds")
874
+ while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
875
+ self._inner_generation_loop(batch_processor)
876
+ self.current_batch += 1
877
+
878
+ except Exception as e:
879
+ logger.error(f"Error in generation loop: {e}", exc_info=True)
880
+ self._handle_critical_error(e, batch_processor)
881
+ finally:
882
+ logger.info("Generation loop finished.")
883
+
884
+ @traced(span_name="generation_loop")
885
+ def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor):
886
+ if torch.cuda.is_available():
887
+ torch.cuda.synchronize()
888
+ if not batch_processor.prepare_next_batch():
889
+ return
890
+ if logger.getEffectiveLevel() <= logging.DEBUG:
891
+ device, total, reserved, allocated = get_device_and_memory_breakdown()
892
+ logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
893
+ if torch.cuda.is_available() and self.use_cuda_graph:
894
+ if self.current_batch == 0:
895
+ self.warmup(batch_processor)
896
+ elif hasattr(self, "graph"):
897
+ try:
898
+ self._graph_replay()
899
+ except Exception as e:
900
+ logger.error(f"Model forward pass failed: {e}", exc_info=True)
901
+ batch_processor.handle_batch_error(e)
902
+ return
903
+ else:
904
+ self._generation_step(batch_processor)
905
+ else:
906
+ self._generation_step(batch_processor)
907
+ if torch.cuda.is_available():
908
+ torch.cuda.synchronize()
909
+ batch_processor.update_batch()
910
+
911
+ @traced(span_name="graph_replay")
912
+ def _graph_replay(self):
913
+ self.graph.replay()
914
+
915
+ @traced
916
+ def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]):
917
+ """Handle critical errors that terminate the generation loop."""
918
+ # Signal stop
919
+ self.stop_event.set()
920
+
921
+ # Fail pending requests in input queue
922
+ try:
923
+ while True:
924
+ req_data = self.input_queue.get_nowait()
925
+ if batch_processor is not None:
926
+ batch_processor._handle_request_error(error, req_data)
927
+ except queue.Empty:
928
+ pass
929
+
930
+ # Fail active requests
931
+ if batch_processor is not None:
932
+ batch_processor.fail_all_requests(error)
933
+
934
+ @traced
935
+ def evict_request_from_cache(self, request_id: str):
936
+ """Evict a request from the cache. It is assumed that the request is already finished."""
937
+ if not self.manual_eviction:
938
+ raise RuntimeError("Manual eviction is not enabled for this manager.")
939
+ if self.batch_processor is not None:
940
+ self.batch_processor.scheduler.finish_request(request_id)
941
+
942
+
943
+ class ContinuousMixin:
944
+ """Mixin class for models to add continuous batching capabilities."""
945
+
946
+ def init_continuous_batching(
947
+ self,
948
+ generation_config: Optional[GenerationConfig] = None,
949
+ manual_eviction: bool = False,
950
+ max_queue_size: int = 0,
951
+ streaming: bool = False,
952
+ slice_inputs: bool = True,
953
+ ) -> ContinuousBatchingManager:
954
+ """Initialize a manager for continuous batching inference.
955
+
956
+ Args:
957
+ generation_config: Custom generation configuration
958
+ max_queue_size: Maximum size of the input request queue
959
+ streaming: Whether to stream tokens as they are generated
960
+
961
+ Returns:
962
+ `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
963
+ """
964
+ if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"):
965
+ raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.")
966
+
967
+ gen_config = generation_config if generation_config is not None else self.generation_config
968
+ if gen_config is None:
969
+ raise ValueError("A GenerationConfig must be provided or set in the model.")
970
+
971
+ if gen_config.eos_token_id is None:
972
+ logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).")
973
+ gen_config.eos_token_id = -1
974
+
975
+ # Create and return the manager
976
+ return ContinuousBatchingManager(
977
+ model=self,
978
+ generation_config=gen_config,
979
+ manual_eviction=manual_eviction,
980
+ max_queue_size=max_queue_size,
981
+ streaming=streaming,
982
+ slice_inputs=slice_inputs,
983
+ )
984
+
985
+ @traced
986
+ @torch.inference_mode()
987
+ def generate_batch(
988
+ self,
989
+ inputs: list[list[int]],
990
+ generation_config: Optional[GenerationConfig] = None,
991
+ progress_bar: bool = True,
992
+ slice_inputs: bool = True,
993
+ **kwargs,
994
+ ) -> dict[str, GenerationOutput]:
995
+ """Generate sequences for a batch of prompts using continuous batching.
996
+
997
+ Args:
998
+ inputs: List of input token sequences (prompts)
999
+ generation_config: Optional generation configuration
1000
+ **kwargs: Additional generation parameters
1001
+
1002
+ Returns:
1003
+ `dict[str, GenerationOutput]`: A mapping from request ID to the finished generation
1004
+ output for each input prompt. Requests that fail before finishing are not
1005
+ included in the mapping.
1006
+ """
1007
+ if not inputs:
1008
+ return {}
1009
+ if logger.getEffectiveLevel() <= logging.DEBUG:
1010
+ logger.warning("Progress bar is disabled when logger level is less than DEBUG")
1011
+ progress_bar = False
1012
+
1013
+ # Initialize manager with the batch inputs
1014
+ manager = self.init_continuous_batching(generation_config=generation_config, slice_inputs=slice_inputs)
1015
+ manager.start()
1016
+ results = {}
1017
+ num_requests = len(inputs)
1018
+ try:
1019
+ from tqdm.contrib.logging import logging_redirect_tqdm
1020
+
1021
+ with logging_redirect_tqdm([logger]):
1022
+ with tqdm(
1023
+ total=num_requests,
1024
+ disable=(not progress_bar),
1025
+ desc=f"Solving {num_requests} requests",
1026
+ unit="request",
1027
+ ) as pbar:
1028
+ manager.add_requests(inputs, **kwargs)
1029
+ finished_count = 0
1030
+ while finished_count < num_requests:
1031
+ result = manager.get_result(timeout=1)
1032
+ if result:
1033
+ req_id = result.request_id
1034
+ if result.status == RequestStatus.FINISHED:
1035
+ results[req_id] = result
1036
+ finished_count += 1
1037
+ pbar.update(1)
1038
+ else:
1039
+ if not manager.is_running():
1040
+ logger.error("Generation thread terminated unexpectedly.")
1041
+ break
1042
+
1043
+ except Exception as e:
1044
+ logger.error(f"Error during batch generation: {e}", exc_info=True)
1045
+ finally:
1046
+ manager.stop(block=True, timeout=5.0)
1047
+ return results
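+ # Illustrative usage sketch (added for clarity; not part of the original file,
+ # and `model`/`tokenizer` are assumed to be a loaded causal LM inheriting this
+ # mixin and its tokenizer):
+ #
+ #     prompts = [tokenizer(p)["input_ids"] for p in ("Hello", "Hi there")]
+ #     results = model.generate_batch(prompts, max_new_tokens=16)
+ #     for request_id, output in results.items():
+ #         print(request_id, tokenizer.decode(output.generated_tokens))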
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py ADDED
@@ -0,0 +1,204 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import time
16
+ from dataclasses import dataclass, field
17
+ from enum import Enum
18
+ from typing import Optional
19
+
20
+ import torch
21
+
22
+ from ...utils.logging import logging
23
+ from ...utils.metrics import traced
24
+
25
+
26
+ # We centralize the logger here to coordinate between logging and progress bar
27
+ logger = logging.getLogger("ContinuousBatchingLogger")
28
+ # logger.setLevel(logging.INFO)
29
+
30
+
31
+ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
32
+ if torch.cuda.is_available():
33
+ device = torch.device("cuda")
34
+ torch.cuda.empty_cache()
35
+ torch.cuda.synchronize()
36
+ total_memory = torch.cuda.get_device_properties(device).total_memory
37
+ reserved_memory = torch.cuda.memory_reserved(device)
38
+ allocated_memory = torch.cuda.memory_allocated(device)
39
+ elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
40
+ device = torch.device("mps")
41
+ # MPS memory reporting (PyTorch 2.0+)
42
+ total_memory = torch.mps.driver_allocated_memory()
43
+ allocated_memory = total_memory - torch.mps.recommended_max_memory()
44
+ reserved_memory = 0 # MPS does not track reserved separately
45
+ else:
46
+ device = torch.device("cpu")
47
+ total_memory = None
48
+ reserved_memory = 0
49
+ allocated_memory = 0
50
+ return device, total_memory, reserved_memory, allocated_memory
51
+
52
+
53
+ class RequestStatus(Enum):
54
+ """Status of a generation request through its lifecycle."""
55
+
56
+ PENDING = "pending"
57
+ PREFILLING = "prefilling"
58
+ PREFILLING_SPLIT = "prefilling_split"
59
+ SPLIT_PENDING_REMAINDER = "split_pending_remainder"
60
+ DECODING = "decoding"
61
+ FINISHED = "finished"
62
+ FAILED = "failed"
63
+
64
+
65
+ @dataclass
66
+ class GenerationOutput:
67
+ """Tracks the output of a generation request.
68
+
69
+ Attributes:
70
+ request_id (str): The ID of the generation request.
71
+ prompt_ids (list[int]): The IDs of the prompt tokens.
72
+ generated_tokens (list[int]): The generated tokens.
73
+ logprobs (list[float]): The log probabilities of the generated tokens.
74
+ error (Optional[str]): Any error message associated with the request. When None, the request was successful.
75
+ status (RequestStatus): The status of the request.
76
+ created_time (float): The time the request was created.
77
+ """
78
+
79
+ request_id: str
80
+ prompt_ids: list[int] = field(default_factory=list)
81
+ generated_tokens: list[int] = field(default_factory=list)
82
+ logprobs: list[float] = field(default_factory=list)
83
+ error: Optional[str] = None
84
+ status: RequestStatus = RequestStatus.PENDING
85
+ created_time: float = field(default_factory=time.time)
86
+
87
+
88
+ @dataclass
89
+ class RequestState:
90
+ """Tracks the state of a generation request through its lifecycle.
91
+
92
+ Attributes:
93
+ request_id (str): The ID of the generation request.
94
+ full_prompt_ids (list[int] | None): The tokens IDs of the full prompt.
95
+ prompt_ids (list[int] | None): The tokens IDs currently being processed.
96
+ remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests).
97
+ static_outputs (list[int]): The generated tokens.
98
+ allocated_blocks (int): The number of blocks allocated to the request.
99
+ position_offset (int): The current position in the sequence for position_ids.
100
+ status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
101
+ SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
102
+ max_new_tokens (int): The maximum number of new tokens to generate.
103
+ eos_token_id (int): The ID of the end-of-sequence token.
104
+ created_time (float): The time the request was created.
105
+ error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
106
+ """
107
+
108
+ # Required fields
109
+ request_id: str
110
+ full_prompt_ids: Optional[list[int]] = None # Full initial prompt
111
+ prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
112
+ remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
113
+ static_outputs: list[int] = field(default_factory=list) # Generated tokens
114
+ allocated_blocks: int = 0 # Number of blocks allocated to the request
115
+ position_offset: int = 0 # Current position in the sequence for position_ids
116
+ _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property
117
+ max_new_tokens: int = 20 # Maximum number of new tokens to generate
118
+ eos_token_id: int = -1 # ID of the end-of-sequence token
119
+ created_time: float = field(default_factory=time.time) # Time the request was created
120
+ error: Optional[str] = None # Error message if the request failed
121
+ lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished)
122
+
123
+ @property
124
+ def status(self) -> RequestStatus:
125
+ return self._status
126
+
127
+ @status.setter
128
+ def status(self, value: RequestStatus):
129
+ if self._status == RequestStatus.PENDING:
130
+ self.lifespan = (time.time(), -1)
131
+ elif value == RequestStatus.FINISHED:
132
+ self.lifespan = (self.lifespan[0], time.time())
133
+ self.log_end_of_request()
134
+ self._status = value
135
+
136
+ def log_end_of_request(self):
137
+ prefill_len = len(self.full_prompt_ids)
138
+ decode_len = self.generated_len()
139
+ start_time = self.lifespan[0] - self.created_time
140
+ end_time = self.lifespan[1] - self.created_time
141
+ logger.info(
142
+ f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
143
+ )
144
+
145
+ def current_len(self) -> int:
146
+ """Get the current length of the sequence (prompt + generated tokens)."""
147
+ return self.position_offset
148
+
149
+ def generated_len(self) -> int:
150
+ """Get the number of tokens generated so far."""
151
+ return len(self.static_outputs)
152
+
153
+ # TODO: this logic seems one token off, check it out
154
+ @traced
155
+ def update_with_token(self, token_id: int) -> bool:
156
+ """Update the request with a newly generated token and check for completion.
157
+
158
+ Args:
159
+ token_id: The token ID to add to the output sequence
160
+
161
+ Returns:
162
+ bool: True if the request is now complete, False otherwise
163
+ """
164
+ # Only update if we're in decoding state
165
+ if self.status != RequestStatus.DECODING:
166
+ return False
167
+
168
+ is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
169
+ is_max_len = self.generated_len() >= self.max_new_tokens
170
+
171
+ # Only add the token if we're not finishing due to max length
172
+ # (EOS tokens should still be added to the output)
173
+ if not (is_max_len and not is_eos):
174
+ self.static_outputs.extend([token_id])
175
+
176
+ if is_eos or is_max_len:
177
+ self.status = RequestStatus.FINISHED
178
+ return True
179
+ return False
180
+
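+ # Illustrative walk-through (added for clarity; not part of the original file):
+ # for a DECODING request with max_new_tokens=2 and eos_token_id=2:
+ #
+ #     state.update_with_token(5)   # False, static_outputs == [5]
+ #     state.update_with_token(7)   # False, static_outputs == [5, 7]
+ #     state.update_with_token(9)   # True (max length reached), 9 is not appended
+ #
+ # An EOS token that arrives before the limit is appended: on a fresh state,
+ # update_with_token(2) returns True with static_outputs == [2].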
181
+ def __repr__(self):
182
+ msg = [
183
+ f"request_id={self.request_id}",
184
+ f"status={self._status}",
185
+ f"out_tokens={self.generated_len()}",
186
+ f"query_length={len(self.prompt_ids)}",
187
+ f"remaining_tokens={len(self.remaining_prompt_ids)}",
188
+ f"kv_length={self.position_offset}",
189
+ f"full_prompt_length={len(self.full_prompt_ids)}",
190
+ f"allocated_blocks={self.allocated_blocks}",
191
+ f"generated_tokens={self.static_outputs}",
192
+ ]
193
+ return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"
194
+
195
+ def to_generation_output(self):
196
+ """Convert the request state to a GenerationOutput object."""
197
+ return GenerationOutput(
198
+ request_id=self.request_id,
199
+ prompt_ids=self.full_prompt_ids,
200
+ status=self.status,
201
+ generated_tokens=self.static_outputs,
202
+ logprobs=[],
203
+ error=self.error,
204
+ )
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py ADDED
@@ -0,0 +1,300 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import threading
16
+ from abc import ABC, abstractmethod
17
+ from collections import deque
18
+
19
+ from ...utils.metrics import attach_tracer, traced
20
+ from .cache import PagedAttentionCache
21
+ from .requests import RequestState, RequestStatus
22
+
23
+
24
+ class Scheduler(ABC):
25
+ """
26
+ Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of
27
+ requests from when they are added to the waiting queue to when they are scheduled for processing. Different
28
+ schedulers implement different strategies for prioritizing and batching requests.
29
+ """
30
+
31
+ def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
32
+ self.active_requests: dict[str, RequestState] = {}
33
+ self.waiting_requests: dict[str, RequestState] = {}
34
+ self.waiting_requests_order: deque[str] = deque()
35
+ self.cache = cache
36
+ self.retain_cache_on_finish = retain_cache_on_finish
37
+ self._cancellation_lock = threading.Lock()
38
+ self._requests_to_cancel: set[str] = set()
39
+
40
+ @traced
41
+ def add_waiting_request(self, state: RequestState):
42
+ """Adds a request to the waiting list."""
43
+ if self.retain_cache_on_finish and state.request_id in self.active_requests:
44
+ old_state = self.active_requests.pop(state.request_id)
45
+ state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error?
46
+ state.allocated_blocks = old_state.allocated_blocks
47
+ state.position_offset = old_state.position_offset
48
+ self.waiting_requests[state.request_id] = state
49
+ self.waiting_requests_order.append(state.request_id)
50
+
51
+ @abstractmethod
52
+ def schedule_batch(self, token_budget: int) -> list[RequestState]:
53
+ """Schedules requests for the next batch based on available token budget. This method selects which requests
54
+ should be processed in the current batch, considering the token budget and the scheduler's prioritization rules.
55
+ The token_budget is the maximum number of tokens that can be processed in this batch."""
56
+ pass
57
+
58
+ @traced
59
+ def has_pending_requests(self) -> bool:
60
+ """Checks if there are requests ready to be processed."""
61
+ return len(self.active_requests) > 0 or len(self.waiting_requests) > 0
62
+
63
+ @traced
64
+ def finish_request(self, request_id: str, evict_from_cache: bool = True):
65
+ """Completes processing of a request and optionally frees its allocated cache blocks. This method is called
66
+ when a request has finished generation or encountered an error.
67
+ """
68
+ if evict_from_cache:
69
+ self.cache.free_blocks(request_id)
70
+ if request_id in self.active_requests:
71
+ del self.active_requests[request_id]
72
+
73
+ @traced
74
+ def get_active_request_static_outputs(self, request_id: str) -> list[int]:
75
+ """Gets generated tokens for an active request."""
76
+ if request_id in self.active_requests:
77
+ return self.active_requests[request_id].static_outputs
78
+ return []
79
+
80
+ @traced
81
+ def set_request_cancellation(self, request_id: str):
82
+ """Marks a request for cancellation."""
83
+ with self._cancellation_lock:
84
+ self._requests_to_cancel.add(request_id)
85
+
86
+ @traced
87
+ def clear_cancelled_requests(self):
88
+ """Remove all cancelled requests from active and waiting queues."""
89
+ with self._cancellation_lock:
90
+ for request_id in self._requests_to_cancel:
91
+ if request_id in self.active_requests:
92
+ del self.active_requests[request_id]
93
+ if request_id in self.waiting_requests:
94
+ del self.waiting_requests[request_id]
95
+ if request_id in self.waiting_requests_order:
96
+ self.waiting_requests_order.remove(request_id)
97
+ self.cache.free_blocks(request_id)
98
+ self._requests_to_cancel = set()
99
+
100
+ @traced
101
+ def request_is_cancelled(self, request_id: str) -> bool:
102
+ """Checks if a request has been cancelled or removed."""
103
+ return request_id in self._requests_to_cancel or (
104
+ request_id not in self.active_requests and request_id not in self.waiting_requests
105
+ )
106
+
107
+ @traced
108
+ def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
109
+ """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
110
+ accommodate the next tokens. It calculates how many blocks are needed based on the request's current
111
+ cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
112
+ objects. Returns a boolean indicating if the allocation was successful or not.
113
+ """
114
+ # 1. we check that the occupancy is less than the requested length
115
+ # 2. we allocate enough blocks to cover the requested length
116
+ current_len = state.current_len()
117
+ occupancy = state.allocated_blocks * self.cache.block_size - current_len
118
+ if occupancy < len_next_tokens or state.allocated_blocks == 0:
119
+ blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
120
+ allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
121
+ if allocated is None:
122
+ return False
123
+ state.allocated_blocks += allocated
124
+ return True
125
+
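To make the allocation arithmetic above concrete, here is a worked example with a hypothetical block size of 16 (the real value comes from the cache configuration):

block_size = 16                       # assumed; self.cache.block_size in practice
allocated_blocks, current_len = 2, 30
len_next_tokens = 10

occupancy = allocated_blocks * block_size - current_len    # 32 - 30 = 2 free slots
assert occupancy < len_next_tokens                         # 2 < 10 -> must allocate
blocks_needed = ((len_next_tokens - occupancy + 1) // block_size) + 1
assert blocks_needed == 1                                  # (9 // 16) + 1
# After allocation: 3 blocks * 16 = 48 slots >= 30 + 10 tokens.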
126
+ @traced(span_name="prepare_request")
127
+ def _prepare_request_for_processing(
128
+ self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
129
+ ):
130
+ """Prepares a request for processing in the current batch."""
131
+ request_tokens = (
132
+ state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
133
+ )
134
+ if len(request_tokens) < token_budget:
135
+ # Can process the entire prompt/remainder
136
+ if state.status == RequestStatus.PENDING:
137
+ self.active_requests[state.request_id] = state
138
+ state.status = RequestStatus.PREFILLING
139
+ request_ids_to_remove_from_waiting.add(state.request_id)
140
+ elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
141
+ state.status = RequestStatus.PREFILLING
142
+ state.prompt_ids = state.remaining_prompt_ids
143
+ state.remaining_prompt_ids = []
144
+ else:
145
+ # Need to split the request
146
+ if state.status == RequestStatus.PENDING:
147
+ self.active_requests[state.request_id] = state
148
+ state.status = RequestStatus.PREFILLING_SPLIT
149
+ request_ids_to_remove_from_waiting.add(state.request_id)
150
+ elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
151
+ state.status = RequestStatus.PREFILLING_SPLIT
152
+ state.remaining_prompt_ids = request_tokens[token_budget:]
153
+ state.prompt_ids = request_tokens[:token_budget]
154
+
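A concrete illustration of the splitting branch above, with a hypothetical 10-token prompt and a token budget of 4:

prompt_ids, token_budget = list(range(10)), 4   # assumed sizes for illustration

# First pass: the request becomes PREFILLING_SPLIT and only a chunk is prefilled.
chunk, remainder = prompt_ids[:token_budget], prompt_ids[token_budget:]
assert chunk == [0, 1, 2, 3] and len(remainder) == 6
# Later passes consume `remainder` as SPLIT_PENDING_REMAINDER until it fits the budget.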
155
+
156
+ @attach_tracer()
157
+ class FIFOScheduler(Scheduler):
158
+ """This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
159
+ prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
160
+ when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests."""
161
+
162
+ def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
163
+ """Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop
164
+ scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks,
165
+ or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
166
+ """
167
+ super().__init__(cache, retain_cache_on_finish)
168
+ self.safety_margin = safety_margin
169
+
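For instance, with the default safety_margin=0.2 and a hypothetical cache of 1000 blocks, new prefills stop being scheduled once fewer than 200 blocks remain free:

num_blocks, safety_margin = 1000, 0.2    # assumed cache size for illustration
threshold = safety_margin * num_blocks   # 200 blocks
free_blocks = 150
outside_safety_margin = free_blocks < threshold   # True -> only decoding proceeds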
170
+ @traced
171
+ def schedule_batch(self, token_budget: int) -> list[RequestState]:
172
+ priority_states: list[RequestState] = []
173
+ second_priority_states: list[RequestState] = []
174
+ scheduled_requests = []
175
+
176
+ for state in self.active_requests.values():
177
+ if state.status == RequestStatus.DECODING:
178
+ priority_states.append(state)
179
+ if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]:
180
+ second_priority_states.append(state)
181
+
182
+ # Add waiting requests to second priority
183
+ for req_id in self.waiting_requests_order:
184
+ second_priority_states.append(self.waiting_requests[req_id])
185
+
186
+ candidates = priority_states + second_priority_states
187
+ request_ids_to_remove_from_waiting = set()
188
+ safety_margins = self.safety_margin * self.cache.num_blocks
189
+
190
+ for state in candidates:
191
+ # If we are outside the safety margin, we only accept decoding requests or the first prefill request
192
+ num_free_blocks = self.cache.get_num_free_blocks()
193
+ outside_safety_margin = num_free_blocks < safety_margins
194
+ if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
195
+ break
196
+
197
+ self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
198
+ request_len = len(state.prompt_ids)
199
+ if not self._allocate_blocks_if_needed(
200
+ state, len(state.prompt_ids)
201
+ ): # don't schedule if we can't allocate blocks
202
+ if len(self.cache._free_blocks) == 0:
203
+ break
204
+ continue
205
+
206
+ @traced
207
+ def _add_to_scheduled_requests(state: RequestState):
208
+ scheduled_requests.append(state)
209
+
210
+ _add_to_scheduled_requests(state)
211
+
212
+ token_budget -= request_len
213
+
214
+ @traced
215
+ def _remove_from_waiting_requests(state: RequestState):
216
+ req_id = state.request_id
217
+ if req_id in self.waiting_requests:
218
+ del self.waiting_requests[req_id]
219
+ request_ids_to_remove_from_waiting.add(req_id)
220
+
221
+ _remove_from_waiting_requests(state)
222
+
223
+ if token_budget == 0:
224
+ break
225
+
226
+ self.waiting_requests_order = deque(
227
+ [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
228
+ )
229
+
230
+ return scheduled_requests
231
+
232
+
233
+ # FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
234
+ @attach_tracer()
235
+ class PrefillFirstScheduler(Scheduler):
236
+ """Scheduler that prioritizes split prefill requests over decoding requests. This scheduler ensures that split
237
+ prefill requests (which are continuations of partially processed prompts) are completed before processing new
238
+ decoding requests."""
239
+
240
+ @traced
241
+ def schedule_batch(self, token_budget: int) -> list[RequestState]:
242
+ priority_states: list[RequestState] = []
243
+ second_priority_states: list[RequestState] = []
244
+ scheduled_requests = []
245
+
246
+ for state in self.active_requests.values():
247
+ # XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account
248
+ if state.status in [RequestStatus.PREFILLING_SPLIT, RequestStatus.SPLIT_PENDING_REMAINDER]:
249
+ priority_states.append(state)
250
+ elif state.status == RequestStatus.DECODING:
251
+ second_priority_states.append(state)
252
+
253
+ for req_id in self.waiting_requests_order:
254
+ second_priority_states.append(self.waiting_requests[req_id])
255
+
256
+ candidates = priority_states + second_priority_states
257
+
258
+ request_ids_to_remove_from_waiting = set()
259
+
260
+ for state in candidates:
261
+ self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
262
+ request_len = len(state.prompt_ids)
263
+ if not self._allocate_blocks_if_needed(
264
+ state, len(state.prompt_ids)
265
+ ): # don't schedule if we can't allocate blocks
266
+ if len(self.cache._free_blocks) == 0:
267
+ break
268
+ continue
269
+
270
+ @traced
271
+ def _add_to_scheduled_requests(state: RequestState):
272
+ scheduled_requests.append(state)
273
+
274
+ _add_to_scheduled_requests(state)
275
+
276
+ token_budget -= request_len
277
+
278
+ @traced
279
+ def _remove_from_waiting_requests(state: RequestState):
280
+ req_id = state.request_id
281
+ if req_id in self.waiting_requests:
282
+ del self.waiting_requests[req_id]
283
+ request_ids_to_remove_from_waiting.add(req_id)
284
+
285
+ _remove_from_waiting_requests(state)
286
+
287
+ if token_budget == 0:
288
+ break
289
+
290
+ self.waiting_requests_order = deque(
291
+ [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
292
+ )
293
+
294
+ return scheduled_requests
295
+
296
+
297
+ SCHEDULER_MAPPING = {
298
+ "fifo": FIFOScheduler,
299
+ "prefill_first": PrefillFirstScheduler,
300
+ }
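A hedged sketch of how this mapping might be consumed by calling code; `scheduler_name` and `cache` are assumed to come from the caller's configuration (the `PagedAttentionCache` construction is elided):

# Look up and instantiate a scheduler by name, defaulting to FIFO.
scheduler_cls = SCHEDULER_MAPPING.get(scheduler_name, FIFOScheduler)
scheduler = scheduler_cls(cache)

# One step of a serving loop: schedule work under a token budget, run it,
# then retire requests that produced their final token.
for state in scheduler.schedule_batch(token_budget=256):
    ...  # run prefill/decode for `state`, then scheduler.finish_request(...) when done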
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_albert import *
22
+ from .modeling_albert import *
23
+ from .modeling_flax_albert import *
24
+ from .modeling_tf_albert import *
25
+ from .tokenization_albert import *
26
+ from .tokenization_albert_fast import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py ADDED
@@ -0,0 +1,170 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ALBERT model configuration"""
17
+
18
+ from collections import OrderedDict
19
+ from collections.abc import Mapping
20
+
21
+ from ...configuration_utils import PretrainedConfig
22
+ from ...onnx import OnnxConfig
23
+
24
+
25
+ class AlbertConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used
28
+ to instantiate an ALBERT model according to the specified arguments, defining the model architecture. Instantiating
29
+ a configuration with the defaults will yield a similar configuration to that of the ALBERT
30
+ [albert/albert-xxlarge-v2](https://huggingface.co/albert/albert-xxlarge-v2) architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 30000):
37
+ Vocabulary size of the ALBERT model. Defines the number of different tokens that can be represented by the
38
+ `inputs_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
39
+ embedding_size (`int`, *optional*, defaults to 128):
40
+ Dimensionality of vocabulary embeddings.
41
+ hidden_size (`int`, *optional*, defaults to 4096):
42
+ Dimensionality of the encoder layers and the pooler layer.
43
+ num_hidden_layers (`int`, *optional*, defaults to 12):
44
+ Number of hidden layers in the Transformer encoder.
45
+ num_hidden_groups (`int`, *optional*, defaults to 1):
46
+ Number of groups for the hidden layers, parameters in the same group are shared.
47
+ num_attention_heads (`int`, *optional*, defaults to 64):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ intermediate_size (`int`, *optional*, defaults to 16384):
50
+ The dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
51
+ inner_group_num (`int`, *optional*, defaults to 1):
52
+ The number of inner repetition of attention and ffn.
53
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`):
54
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
55
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
56
+ hidden_dropout_prob (`float`, *optional*, defaults to 0):
57
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
58
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
59
+ The dropout ratio for the attention probabilities.
60
+ max_position_embeddings (`int`, *optional*, defaults to 512):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ (e.g., 512 or 1024 or 2048).
63
+ type_vocab_size (`int`, *optional*, defaults to 2):
64
+ The vocabulary size of the `token_type_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
68
+ The epsilon used by the layer normalization layers.
69
+ classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
70
+ The dropout ratio for attached classifiers.
71
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
72
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
73
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
74
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
75
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
76
+ with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
77
+ pad_token_id (`int`, *optional*, defaults to 0):
78
+ Padding token id.
79
+ bos_token_id (`int`, *optional*, defaults to 2):
80
+ Beginning of stream token id.
81
+ eos_token_id (`int`, *optional*, defaults to 3):
82
+ End of stream token id.
83
+
84
+ Examples:
85
+
86
+ ```python
87
+ >>> from transformers import AlbertConfig, AlbertModel
88
+
89
+ >>> # Initializing an ALBERT-xxlarge style configuration
90
+ >>> albert_xxlarge_configuration = AlbertConfig()
91
+
92
+ >>> # Initializing an ALBERT-base style configuration
93
+ >>> albert_base_configuration = AlbertConfig(
94
+ ... hidden_size=768,
95
+ ... num_attention_heads=12,
96
+ ... intermediate_size=3072,
97
+ ... )
98
+
99
+ >>> # Initializing a model (with random weights) from the ALBERT-base style configuration
100
+ >>> model = AlbertModel(albert_base_configuration)
101
+
102
+ >>> # Accessing the model configuration
103
+ >>> configuration = model.config
104
+ ```"""
105
+
106
+ model_type = "albert"
107
+
108
+ def __init__(
109
+ self,
110
+ vocab_size=30000,
111
+ embedding_size=128,
112
+ hidden_size=4096,
113
+ num_hidden_layers=12,
114
+ num_hidden_groups=1,
115
+ num_attention_heads=64,
116
+ intermediate_size=16384,
117
+ inner_group_num=1,
118
+ hidden_act="gelu_new",
119
+ hidden_dropout_prob=0,
120
+ attention_probs_dropout_prob=0,
121
+ max_position_embeddings=512,
122
+ type_vocab_size=2,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-12,
125
+ classifier_dropout_prob=0.1,
126
+ position_embedding_type="absolute",
127
+ pad_token_id=0,
128
+ bos_token_id=2,
129
+ eos_token_id=3,
130
+ **kwargs,
131
+ ):
132
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
133
+
134
+ self.vocab_size = vocab_size
135
+ self.embedding_size = embedding_size
136
+ self.hidden_size = hidden_size
137
+ self.num_hidden_layers = num_hidden_layers
138
+ self.num_hidden_groups = num_hidden_groups
139
+ self.num_attention_heads = num_attention_heads
140
+ self.inner_group_num = inner_group_num
141
+ self.hidden_act = hidden_act
142
+ self.intermediate_size = intermediate_size
143
+ self.hidden_dropout_prob = hidden_dropout_prob
144
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
145
+ self.max_position_embeddings = max_position_embeddings
146
+ self.type_vocab_size = type_vocab_size
147
+ self.initializer_range = initializer_range
148
+ self.layer_norm_eps = layer_norm_eps
149
+ self.classifier_dropout_prob = classifier_dropout_prob
150
+ self.position_embedding_type = position_embedding_type
151
+
152
+
153
+ # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->Albert
154
+ class AlbertOnnxConfig(OnnxConfig):
155
+ @property
156
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
157
+ if self.task == "multiple-choice":
158
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
159
+ else:
160
+ dynamic_axis = {0: "batch", 1: "sequence"}
161
+ return OrderedDict(
162
+ [
163
+ ("input_ids", dynamic_axis),
164
+ ("attention_mask", dynamic_axis),
165
+ ("token_type_ids", dynamic_axis),
166
+ ]
167
+ )
168
+
169
+
170
+ __all__ = ["AlbertConfig", "AlbertOnnxConfig"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py ADDED
@@ -0,0 +1,1349 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ALBERT model."""
16
+
17
+ import math
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import Optional, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
28
+ from ...modeling_outputs import (
29
+ BaseModelOutput,
30
+ BaseModelOutputWithPooling,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ QuestionAnsweringModelOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...pytorch_utils import (
39
+ apply_chunking_to_forward,
40
+ find_pruneable_heads_and_indices,
41
+ prune_linear_layer,
42
+ )
43
+ from ...utils import ModelOutput, auto_docstring, logging
44
+ from .configuration_albert import AlbertConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
51
+ """Load tf checkpoints in a pytorch model."""
52
+ try:
53
+ import re
54
+
55
+ import numpy as np
56
+ import tensorflow as tf
57
+ except ImportError:
58
+ logger.error(
59
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
60
+ "https://www.tensorflow.org/install/ for installation instructions."
61
+ )
62
+ raise
63
+ tf_path = os.path.abspath(tf_checkpoint_path)
64
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
65
+ # Load weights from TF model
66
+ init_vars = tf.train.list_variables(tf_path)
67
+ names = []
68
+ arrays = []
69
+ for name, shape in init_vars:
70
+ logger.info(f"Loading TF weight {name} with shape {shape}")
71
+ array = tf.train.load_variable(tf_path, name)
72
+ names.append(name)
73
+ arrays.append(array)
74
+
75
+ for name, array in zip(names, arrays):
76
+ print(name)
77
+
78
+ for name, array in zip(names, arrays):
79
+ original_name = name
80
+
81
+ # If saved from the TF HUB module
82
+ name = name.replace("module/", "")
83
+
84
+ # Renaming and simplifying
85
+ name = name.replace("ffn_1", "ffn")
86
+ name = name.replace("bert/", "albert/")
87
+ name = name.replace("attention_1", "attention")
88
+ name = name.replace("transform/", "")
89
+ name = name.replace("LayerNorm_1", "full_layer_layer_norm")
90
+ name = name.replace("LayerNorm", "attention/LayerNorm")
91
+ name = name.replace("transformer/", "")
92
+
93
+ # The feed forward layer had an 'intermediate' step which has been abstracted away
94
+ name = name.replace("intermediate/dense/", "")
95
+ name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
96
+
97
+ # ALBERT attention was split between self and output which have been abstracted away
98
+ name = name.replace("/output/", "/")
99
+ name = name.replace("/self/", "/")
100
+
101
+ # The pooler is a linear layer
102
+ name = name.replace("pooler/dense", "pooler")
103
+
104
+ # The classifier was simplified to predictions from cls/predictions
105
+ name = name.replace("cls/predictions", "predictions")
106
+ name = name.replace("predictions/attention", "predictions")
107
+
108
+ # Naming was changed to be more explicit
109
+ name = name.replace("embeddings/attention", "embeddings")
110
+ name = name.replace("inner_group_", "albert_layers/")
111
+ name = name.replace("group_", "albert_layer_groups/")
112
+
113
+ # Classifier
114
+ if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
115
+ name = "classifier/" + name
116
+
117
+ # No ALBERT model currently handles the next sentence prediction task
118
+ if "seq_relationship" in name:
119
+ name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
120
+ name = name.replace("weights", "weight")
121
+
122
+ name = name.split("/")
123
+
124
+ # Ignore the gradients applied by the LAMB/ADAM optimizers.
125
+ if (
126
+ "adam_m" in name
127
+ or "adam_v" in name
128
+ or "AdamWeightDecayOptimizer" in name
129
+ or "AdamWeightDecayOptimizer_1" in name
130
+ or "global_step" in name
131
+ ):
132
+ logger.info(f"Skipping {'/'.join(name)}")
133
+ continue
134
+
135
+ pointer = model
136
+ for m_name in name:
137
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
138
+ scope_names = re.split(r"_(\d+)", m_name)
139
+ else:
140
+ scope_names = [m_name]
141
+
142
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
143
+ pointer = getattr(pointer, "weight")
144
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
145
+ pointer = getattr(pointer, "bias")
146
+ elif scope_names[0] == "output_weights":
147
+ pointer = getattr(pointer, "weight")
148
+ elif scope_names[0] == "squad":
149
+ pointer = getattr(pointer, "classifier")
150
+ else:
151
+ try:
152
+ pointer = getattr(pointer, scope_names[0])
153
+ except AttributeError:
154
+ logger.info(f"Skipping {'/'.join(name)}")
155
+ continue
156
+ if len(scope_names) >= 2:
157
+ num = int(scope_names[1])
158
+ pointer = pointer[num]
159
+
160
+ if m_name[-11:] == "_embeddings":
161
+ pointer = getattr(pointer, "weight")
162
+ elif m_name == "kernel":
163
+ array = np.transpose(array)
164
+ try:
165
+ if pointer.shape != array.shape:
166
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
167
+ except ValueError as e:
168
+ e.args += (pointer.shape, array.shape)
169
+ raise
170
+ print(f"Initialize PyTorch weight {name} from {original_name}")
171
+ pointer.data = torch.from_numpy(array)
172
+
173
+ return model
174
+
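A minimal usage sketch for the loader above, assuming a TensorFlow ALBERT checkpoint exists at the given (hypothetical) paths:

from transformers import AlbertConfig
from transformers.models.albert.modeling_albert import (
    AlbertForPreTraining,
    load_tf_weights_in_albert,
)

config = AlbertConfig.from_json_file("albert_config.json")   # hypothetical path
model = AlbertForPreTraining(config)
model = load_tf_weights_in_albert(model, config, "model.ckpt-best")  # hypothetical checkpoint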
175
+
176
+ class AlbertEmbeddings(nn.Module):
177
+ """
178
+ Construct the embeddings from word, position and token_type embeddings.
179
+ """
180
+
181
+ def __init__(self, config: AlbertConfig):
182
+ super().__init__()
183
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
184
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
185
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
186
+
187
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
188
+ # any TensorFlow checkpoint file
189
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
190
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
191
+
192
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
193
+ self.register_buffer(
194
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
195
+ )
196
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
197
+ self.register_buffer(
198
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
199
+ )
200
+
201
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
202
+ def forward(
203
+ self,
204
+ input_ids: Optional[torch.LongTensor] = None,
205
+ token_type_ids: Optional[torch.LongTensor] = None,
206
+ position_ids: Optional[torch.LongTensor] = None,
207
+ inputs_embeds: Optional[torch.FloatTensor] = None,
208
+ past_key_values_length: int = 0,
209
+ ) -> torch.Tensor:
210
+ if input_ids is not None:
211
+ input_shape = input_ids.size()
212
+ else:
213
+ input_shape = inputs_embeds.size()[:-1]
214
+
215
+ seq_length = input_shape[1]
216
+
217
+ if position_ids is None:
218
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
219
+
220
+ # Set token_type_ids to the registered buffer defined in the constructor, where it is all zeros. This usually
221
+ # happens when token_type_ids is auto-generated; the registered buffer lets users trace the model without
222
+ # passing token_type_ids, and resolves issue #5664.
223
+ if token_type_ids is None:
224
+ if hasattr(self, "token_type_ids"):
225
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
226
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
227
+ token_type_ids = buffered_token_type_ids_expanded
228
+ else:
229
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
230
+
231
+ if inputs_embeds is None:
232
+ inputs_embeds = self.word_embeddings(input_ids)
233
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
234
+
235
+ embeddings = inputs_embeds + token_type_embeddings
236
+ if self.position_embedding_type == "absolute":
237
+ position_embeddings = self.position_embeddings(position_ids)
238
+ embeddings += position_embeddings
239
+ embeddings = self.LayerNorm(embeddings)
240
+ embeddings = self.dropout(embeddings)
241
+ return embeddings
242
+
243
+
244
+ class AlbertAttention(nn.Module):
245
+ def __init__(self, config: AlbertConfig):
246
+ super().__init__()
247
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
248
+ raise ValueError(
249
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
250
+ f"heads ({config.num_attention_heads}"
251
+ )
252
+
253
+ self.num_attention_heads = config.num_attention_heads
254
+ self.hidden_size = config.hidden_size
255
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
256
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
257
+
258
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
259
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
260
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
261
+
262
+ self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
263
+ self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
264
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
265
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
266
+ self.pruned_heads = set()
267
+
268
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
269
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
270
+ self.max_position_embeddings = config.max_position_embeddings
271
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
272
+
273
+ def prune_heads(self, heads: list[int]) -> None:
274
+ if len(heads) == 0:
275
+ return
276
+ heads, index = find_pruneable_heads_and_indices(
277
+ heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
278
+ )
279
+
280
+ # Prune linear layers
281
+ self.query = prune_linear_layer(self.query, index)
282
+ self.key = prune_linear_layer(self.key, index)
283
+ self.value = prune_linear_layer(self.value, index)
284
+ self.dense = prune_linear_layer(self.dense, index, dim=1)
285
+
286
+ # Update hyper params and store pruned heads
287
+ self.num_attention_heads = self.num_attention_heads - len(heads)
288
+ self.all_head_size = self.attention_head_size * self.num_attention_heads
289
+ self.pruned_heads = self.pruned_heads.union(heads)
290
+
291
+ def forward(
292
+ self,
293
+ hidden_states: torch.Tensor,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ head_mask: Optional[torch.FloatTensor] = None,
296
+ output_attentions: bool = False,
297
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
298
+ batch_size, seq_length, _ = hidden_states.shape
299
+ query_layer = self.query(hidden_states)
300
+ key_layer = self.key(hidden_states)
301
+ value_layer = self.value(hidden_states)
302
+ query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
303
+ 1, 2
304
+ )
305
+ key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
306
+ value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
307
+ 1, 2
308
+ )
309
+
310
+ # Take the dot product between "query" and "key" to get the raw attention scores.
311
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
312
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
313
+
314
+ if attention_mask is not None:
315
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
316
+ attention_scores = attention_scores + attention_mask
317
+
318
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
319
+ seq_length = hidden_states.size()[1]
320
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
321
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
322
+ distance = position_ids_l - position_ids_r
323
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
324
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
325
+
326
+ if self.position_embedding_type == "relative_key":
327
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
328
+ attention_scores = attention_scores + relative_position_scores
329
+ elif self.position_embedding_type == "relative_key_query":
330
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
331
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
332
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
333
+
334
+ # Normalize the attention scores to probabilities.
335
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
336
+
337
+ # This is actually dropping out entire tokens to attend to, which might
338
+ # seem a bit unusual, but is taken from the original Transformer paper.
339
+ attention_probs = self.attention_dropout(attention_probs)
340
+
341
+ # Mask heads if we want to
342
+ if head_mask is not None:
343
+ attention_probs = attention_probs * head_mask
344
+
345
+ context_layer = torch.matmul(attention_probs, value_layer)
346
+ context_layer = context_layer.transpose(2, 1).flatten(2)
347
+
348
+ projected_context_layer = self.dense(context_layer)
349
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
350
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
351
+ return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
352
+
353
+
354
+ class AlbertSdpaAttention(AlbertAttention):
355
+ def __init__(self, config):
356
+ super().__init__(config)
357
+ self.dropout_prob = config.attention_probs_dropout_prob
358
+
359
+ def forward(
360
+ self,
361
+ hidden_states: torch.Tensor,
362
+ attention_mask: Optional[torch.FloatTensor] = None,
363
+ head_mask: Optional[torch.FloatTensor] = None,
364
+ output_attentions: bool = False,
365
+ ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
366
+ if self.position_embedding_type != "absolute" or output_attentions:
367
+ logger.warning(
368
+ "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
369
+ "non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
370
+ "the eager attention implementation, but specifying the eager implementation will be required from "
371
+ "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
372
+ '`attn_implementation="eager"` when loading the model.'
373
+ )
374
+ return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
375
+
376
+ batch_size, seq_len, _ = hidden_states.size()
377
+ query_layer = (
378
+ self.query(hidden_states)
379
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
380
+ .transpose(1, 2)
381
+ )
382
+ key_layer = (
383
+ self.key(hidden_states)
384
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
385
+ .transpose(1, 2)
386
+ )
387
+ value_layer = (
388
+ self.value(hidden_states)
389
+ .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
390
+ .transpose(1, 2)
391
+ )
392
+
393
+ attention_output = torch.nn.functional.scaled_dot_product_attention(
394
+ query=query_layer,
395
+ key=key_layer,
396
+ value=value_layer,
397
+ attn_mask=attention_mask,
398
+ dropout_p=self.dropout_prob if self.training else 0.0,
399
+ is_causal=False,
400
+ )
401
+
402
+ attention_output = attention_output.transpose(1, 2)
403
+ attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
404
+
405
+ projected_context_layer = self.dense(attention_output)
406
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
407
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
408
+ return (layernormed_context_layer,)
409
+
410
+
411
+ ALBERT_ATTENTION_CLASSES = {
412
+ "eager": AlbertAttention,
413
+ "sdpa": AlbertSdpaAttention,
414
+ }
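A short sketch of how the SDPA fallback above can be avoided up front by requesting the eager implementation when loading (the `attn_implementation` argument to `from_pretrained` selects the entry from this mapping):

from transformers import AlbertModel

# Forces ALBERT_ATTENTION_CLASSES["eager"], so output_attentions=True and
# relative position embeddings work without the SDPA fallback warning.
model = AlbertModel.from_pretrained("albert/albert-base-v2", attn_implementation="eager")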
415
+
416
+
417
+ class AlbertLayer(nn.Module):
418
+ def __init__(self, config: AlbertConfig):
419
+ super().__init__()
420
+
421
+ self.config = config
422
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
423
+ self.seq_len_dim = 1
424
+ self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
425
+ self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
426
+ self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
427
+ self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
428
+ self.activation = ACT2FN[config.hidden_act]
429
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
430
+
431
+ def forward(
432
+ self,
433
+ hidden_states: torch.Tensor,
434
+ attention_mask: Optional[torch.FloatTensor] = None,
435
+ head_mask: Optional[torch.FloatTensor] = None,
436
+ output_attentions: bool = False,
437
+ output_hidden_states: bool = False,
438
+ ) -> tuple[torch.Tensor, torch.Tensor]:
439
+ attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
440
+
441
+ ffn_output = apply_chunking_to_forward(
442
+ self.ff_chunk,
443
+ self.chunk_size_feed_forward,
444
+ self.seq_len_dim,
445
+ attention_output[0],
446
+ )
447
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
448
+
449
+ return (hidden_states,) + attention_output[1:] # add attentions if we output them
450
+
451
+ def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
452
+ ffn_output = self.ffn(attention_output)
453
+ ffn_output = self.activation(ffn_output)
454
+ ffn_output = self.ffn_output(ffn_output)
455
+ return ffn_output
456
+
457
+
458
+ class AlbertLayerGroup(nn.Module):
459
+ def __init__(self, config: AlbertConfig):
460
+ super().__init__()
461
+
462
+ self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ attention_mask: Optional[torch.FloatTensor] = None,
468
+ head_mask: Optional[torch.FloatTensor] = None,
469
+ output_attentions: bool = False,
470
+ output_hidden_states: bool = False,
471
+ ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
472
+ layer_hidden_states = ()
473
+ layer_attentions = ()
474
+
475
+ for layer_index, albert_layer in enumerate(self.albert_layers):
476
+ layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
477
+ hidden_states = layer_output[0]
478
+
479
+ if output_attentions:
480
+ layer_attentions = layer_attentions + (layer_output[1],)
481
+
482
+ if output_hidden_states:
483
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
484
+
485
+ outputs = (hidden_states,)
486
+ if output_hidden_states:
487
+ outputs = outputs + (layer_hidden_states,)
488
+ if output_attentions:
489
+ outputs = outputs + (layer_attentions,)
490
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
491
+
492
+
493
+ class AlbertTransformer(nn.Module):
494
+ def __init__(self, config: AlbertConfig):
495
+ super().__init__()
496
+
497
+ self.config = config
498
+ self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
499
+ self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ attention_mask: Optional[torch.FloatTensor] = None,
505
+ head_mask: Optional[torch.FloatTensor] = None,
506
+ output_attentions: bool = False,
507
+ output_hidden_states: bool = False,
508
+ return_dict: bool = True,
509
+ ) -> Union[BaseModelOutput, tuple]:
510
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
511
+
512
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
513
+ all_attentions = () if output_attentions else None
514
+
515
+ head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ # Number of layers in a hidden group
519
+ layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
520
+
521
+ # Index of the hidden group
522
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
523
+
524
+ layer_group_output = self.albert_layer_groups[group_idx](
525
+ hidden_states,
526
+ attention_mask,
527
+ head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
528
+ output_attentions,
529
+ output_hidden_states,
530
+ )
531
+ hidden_states = layer_group_output[0]
532
+
533
+ if output_attentions:
534
+ all_attentions = all_attentions + layer_group_output[-1]
535
+
536
+ if output_hidden_states:
537
+ all_hidden_states = all_hidden_states + (hidden_states,)
538
+
539
+ if not return_dict:
540
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
541
+ return BaseModelOutput(
542
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
543
+ )
544
+
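To make the parameter-sharing arithmetic in the loop above concrete: with a hypothetical 12 hidden layers and 3 hidden groups, consecutive layer indices map onto groups as follows:

num_hidden_layers, num_hidden_groups = 12, 3   # assumed values for illustration
layers_per_group = num_hidden_layers // num_hidden_groups   # 4

group_indices = [int(i / layers_per_group) for i in range(num_hidden_layers)]
assert group_indices == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
# Each group's weights are reused for every layer index mapped to it.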
545
+
546
+ @auto_docstring
547
+ class AlbertPreTrainedModel(PreTrainedModel):
548
+ config: AlbertConfig
549
+ load_tf_weights = load_tf_weights_in_albert
550
+ base_model_prefix = "albert"
551
+ _supports_sdpa = True
552
+
553
+ def _init_weights(self, module):
554
+ """Initialize the weights."""
555
+ if isinstance(module, nn.Linear):
556
+ # Slightly different from the TF version which uses truncated_normal for initialization
557
+ # cf https://github.com/pytorch/pytorch/pull/5617
558
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
559
+ if module.bias is not None:
560
+ module.bias.data.zero_()
561
+ elif isinstance(module, nn.Embedding):
562
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
563
+ if module.padding_idx is not None:
564
+ module.weight.data[module.padding_idx].zero_()
565
+ elif isinstance(module, nn.LayerNorm):
566
+ module.bias.data.zero_()
567
+ module.weight.data.fill_(1.0)
568
+ elif isinstance(module, AlbertMLMHead):
569
+ module.bias.data.zero_()
570
+
571
+
572
+ @dataclass
573
+ @auto_docstring(
574
+ custom_intro="""
575
+ Output type of [`AlbertForPreTraining`].
576
+ """
577
+ )
578
+ class AlbertForPreTrainingOutput(ModelOutput):
579
+ r"""
580
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
581
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
582
+ (classification) loss.
583
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
584
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
585
+ sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
586
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
587
+ before SoftMax).
588
+ """
589
+
590
+ loss: Optional[torch.FloatTensor] = None
591
+ prediction_logits: Optional[torch.FloatTensor] = None
592
+ sop_logits: Optional[torch.FloatTensor] = None
593
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
594
+ attentions: Optional[tuple[torch.FloatTensor]] = None
595
+
596
+
597
+ @auto_docstring
598
+ class AlbertModel(AlbertPreTrainedModel):
599
+ config: AlbertConfig
600
+ base_model_prefix = "albert"
601
+
602
+ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
603
+ r"""
604
+ add_pooling_layer (bool, *optional*, defaults to `True`):
605
+ Whether to add a pooling layer
606
+ """
607
+ super().__init__(config)
608
+
609
+ self.config = config
610
+ self.embeddings = AlbertEmbeddings(config)
611
+ self.encoder = AlbertTransformer(config)
612
+ if add_pooling_layer:
613
+ self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
614
+ self.pooler_activation = nn.Tanh()
615
+ else:
616
+ self.pooler = None
617
+ self.pooler_activation = None
618
+
619
+ self.attn_implementation = config._attn_implementation
620
+ self.position_embedding_type = config.position_embedding_type
621
+
622
+ # Initialize weights and apply final processing
623
+ self.post_init()
624
+
625
+ def get_input_embeddings(self) -> nn.Embedding:
626
+ return self.embeddings.word_embeddings
627
+
628
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
629
+ self.embeddings.word_embeddings = value
630
+
631
+ def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
632
+ """
633
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT has
634
+ a different architecture in that its layers are shared across groups, which in turn have inner groups. If an ALBERT
635
+ model has 12 hidden layers and 2 hidden groups, each with two inner groups, there is a total of 4 different layers.
636
+
637
+ These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden group,
638
+ while [2,3] correspond to the two inner groups of the second hidden group.
639
+
640
+ Any layer with an index other than [0,1,2,3] will result in an error. See the base class PreTrainedModel for more
641
+ information about head pruning.
642
+ """
643
+ for layer, heads in heads_to_prune.items():
644
+ group_idx = int(layer / self.config.inner_group_num)
645
+ inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
646
+ self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
647
+
648
+ @auto_docstring
649
+ def forward(
650
+ self,
651
+ input_ids: Optional[torch.LongTensor] = None,
652
+ attention_mask: Optional[torch.FloatTensor] = None,
653
+ token_type_ids: Optional[torch.LongTensor] = None,
654
+ position_ids: Optional[torch.LongTensor] = None,
655
+ head_mask: Optional[torch.FloatTensor] = None,
656
+ inputs_embeds: Optional[torch.FloatTensor] = None,
657
+ output_attentions: Optional[bool] = None,
658
+ output_hidden_states: Optional[bool] = None,
659
+ return_dict: Optional[bool] = None,
660
+ ) -> Union[BaseModelOutputWithPooling, tuple]:
661
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
662
+ output_hidden_states = (
663
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
664
+ )
665
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
666
+
667
+ if input_ids is not None and inputs_embeds is not None:
668
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
669
+ elif input_ids is not None:
670
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
671
+ input_shape = input_ids.size()
672
+ elif inputs_embeds is not None:
673
+ input_shape = inputs_embeds.size()[:-1]
674
+ else:
675
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
676
+
677
+ batch_size, seq_length = input_shape
678
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
679
+
680
+ if attention_mask is None:
681
+ attention_mask = torch.ones(input_shape, device=device)
682
+ if token_type_ids is None:
683
+ if hasattr(self.embeddings, "token_type_ids"):
684
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
685
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
686
+ token_type_ids = buffered_token_type_ids_expanded
687
+ else:
688
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
689
+
690
+ embedding_output = self.embeddings(
691
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
692
+ )
693
+
694
+ use_sdpa_attention_mask = (
695
+ self.attn_implementation == "sdpa"
696
+ and self.position_embedding_type == "absolute"
697
+ and head_mask is None
698
+ and not output_attentions
699
+ )
700
+
701
+ if use_sdpa_attention_mask:
702
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
703
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
704
+ )
705
+ else:
706
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
707
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
708
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
709
+
710
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
711
+
712
+ encoder_outputs = self.encoder(
713
+ embedding_output,
714
+ extended_attention_mask,
715
+ head_mask=head_mask,
716
+ output_attentions=output_attentions,
717
+ output_hidden_states=output_hidden_states,
718
+ return_dict=return_dict,
719
+ )
720
+
721
+ sequence_output = encoder_outputs[0]
722
+
723
+ pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
724
+
725
+ if not return_dict:
726
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
727
+
728
+ return BaseModelOutputWithPooling(
729
+ last_hidden_state=sequence_output,
730
+ pooler_output=pooled_output,
731
+ hidden_states=encoder_outputs.hidden_states,
732
+ attentions=encoder_outputs.attentions,
733
+ )
734
+
735
+
736
+ @auto_docstring(
737
+ custom_intro="""
738
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
739
+ `sentence order prediction (classification)` head.
740
+ """
741
+ )
742
+ class AlbertForPreTraining(AlbertPreTrainedModel):
743
+ _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
744
+
745
+ def __init__(self, config: AlbertConfig):
746
+ super().__init__(config)
747
+
748
+ self.albert = AlbertModel(config)
749
+ self.predictions = AlbertMLMHead(config)
750
+ self.sop_classifier = AlbertSOPHead(config)
751
+
752
+ # Initialize weights and apply final processing
753
+ self.post_init()
754
+
755
+ def get_output_embeddings(self) -> nn.Linear:
756
+ return self.predictions.decoder
757
+
758
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
759
+ self.predictions.decoder = new_embeddings
760
+
761
+ def get_input_embeddings(self) -> nn.Embedding:
762
+ return self.albert.embeddings.word_embeddings
763
+
764
+ @auto_docstring
765
+ def forward(
766
+ self,
767
+ input_ids: Optional[torch.LongTensor] = None,
768
+ attention_mask: Optional[torch.FloatTensor] = None,
769
+ token_type_ids: Optional[torch.LongTensor] = None,
770
+ position_ids: Optional[torch.LongTensor] = None,
771
+ head_mask: Optional[torch.FloatTensor] = None,
772
+ inputs_embeds: Optional[torch.FloatTensor] = None,
773
+ labels: Optional[torch.LongTensor] = None,
774
+ sentence_order_label: Optional[torch.LongTensor] = None,
775
+ output_attentions: Optional[bool] = None,
776
+ output_hidden_states: Optional[bool] = None,
777
+ return_dict: Optional[bool] = None,
778
+ ) -> Union[AlbertForPreTrainingOutput, tuple]:
779
+ r"""
780
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
781
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
782
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
783
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
784
+ sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
785
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
786
+ (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
787
+ sequence B), `1` indicates switched order (sequence B, then sequence A).
788
+
789
+ Example:
790
+
791
+ ```python
792
+ >>> from transformers import AutoTokenizer, AlbertForPreTraining
793
+ >>> import torch
794
+
795
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
796
+ >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")
797
+
798
+ >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
799
+ >>> # Batch size 1
800
+ >>> outputs = model(input_ids)
801
+
802
+ >>> prediction_logits = outputs.prediction_logits
803
+ >>> sop_logits = outputs.sop_logits
804
+ ```"""
805
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806
+
807
+ outputs = self.albert(
808
+ input_ids,
809
+ attention_mask=attention_mask,
810
+ token_type_ids=token_type_ids,
811
+ position_ids=position_ids,
812
+ head_mask=head_mask,
813
+ inputs_embeds=inputs_embeds,
814
+ output_attentions=output_attentions,
815
+ output_hidden_states=output_hidden_states,
816
+ return_dict=return_dict,
817
+ )
818
+
819
+ sequence_output, pooled_output = outputs[:2]
820
+
821
+ prediction_scores = self.predictions(sequence_output)
822
+ sop_scores = self.sop_classifier(pooled_output)
823
+
824
+ total_loss = None
825
+ if labels is not None and sentence_order_label is not None:
826
+ loss_fct = CrossEntropyLoss()
827
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
828
+ sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
829
+ total_loss = masked_lm_loss + sentence_order_loss
830
+
831
+ if not return_dict:
832
+ output = (prediction_scores, sop_scores) + outputs[2:]
833
+ return ((total_loss,) + output) if total_loss is not None else output
834
+
835
+ return AlbertForPreTrainingOutput(
836
+ loss=total_loss,
837
+ prediction_logits=prediction_scores,
838
+ sop_logits=sop_scores,
839
+ hidden_states=outputs.hidden_states,
840
+ attentions=outputs.attentions,
841
+ )
842
+
843
+
844
+ class AlbertMLMHead(nn.Module):
845
+ def __init__(self, config: AlbertConfig):
846
+ super().__init__()
847
+
848
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
849
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
850
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
851
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
852
+ self.activation = ACT2FN[config.hidden_act]
853
+ self.decoder.bias = self.bias
854
+
855
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
856
+ hidden_states = self.dense(hidden_states)
857
+ hidden_states = self.activation(hidden_states)
858
+ hidden_states = self.LayerNorm(hidden_states)
859
+ hidden_states = self.decoder(hidden_states)
860
+
861
+ prediction_scores = hidden_states
862
+
863
+ return prediction_scores
864
+
865
+ def _tie_weights(self) -> None:
866
+ # For accelerate compatibility and to not break backward compatibility
867
+ if self.decoder.bias.device.type == "meta":
868
+ self.decoder.bias = self.bias
869
+ else:
870
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
871
+ self.bias = self.decoder.bias
872
+
873
+
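The head above maps hidden states back down to ALBERT's smaller factorized embedding size before decoding into the vocabulary. A shape-only sketch with the albert-base-v2 sizes, and `tanh` standing in for `config.hidden_act`:

```python
import torch
import torch.nn as nn

hidden_size, embedding_size, vocab_size = 768, 128, 30000  # albert-base-v2 defaults
dense = nn.Linear(hidden_size, embedding_size)   # project down to the embedding size
decoder = nn.Linear(embedding_size, vocab_size)  # decode into the vocabulary

hidden = torch.randn(1, 5, hidden_size)
scores = decoder(torch.tanh(dense(hidden)))      # tanh stands in for config.hidden_act
print(scores.shape)                              # torch.Size([1, 5, 30000])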
874
+ class AlbertSOPHead(nn.Module):
875
+ def __init__(self, config: AlbertConfig):
876
+ super().__init__()
877
+
878
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
879
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
880
+
881
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
882
+ dropout_pooled_output = self.dropout(pooled_output)
883
+ logits = self.classifier(dropout_pooled_output)
884
+ return logits
885
+
886
+
887
+ @auto_docstring
888
+ class AlbertForMaskedLM(AlbertPreTrainedModel):
889
+ _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
890
+
891
+ def __init__(self, config):
892
+ super().__init__(config)
893
+
894
+ self.albert = AlbertModel(config, add_pooling_layer=False)
895
+ self.predictions = AlbertMLMHead(config)
896
+
897
+ # Initialize weights and apply final processing
898
+ self.post_init()
899
+
900
+ def get_output_embeddings(self) -> nn.Linear:
901
+ return self.predictions.decoder
902
+
903
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
904
+ self.predictions.decoder = new_embeddings
905
+ self.predictions.bias = new_embeddings.bias
906
+
907
+ def get_input_embeddings(self) -> nn.Embedding:
908
+ return self.albert.embeddings.word_embeddings
909
+
910
+ @auto_docstring
911
+ def forward(
912
+ self,
913
+ input_ids: Optional[torch.LongTensor] = None,
914
+ attention_mask: Optional[torch.FloatTensor] = None,
915
+ token_type_ids: Optional[torch.LongTensor] = None,
916
+ position_ids: Optional[torch.LongTensor] = None,
917
+ head_mask: Optional[torch.FloatTensor] = None,
918
+ inputs_embeds: Optional[torch.FloatTensor] = None,
919
+ labels: Optional[torch.LongTensor] = None,
920
+ output_attentions: Optional[bool] = None,
921
+ output_hidden_states: Optional[bool] = None,
922
+ return_dict: Optional[bool] = None,
923
+ ) -> Union[MaskedLMOutput, tuple]:
924
+ r"""
925
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
926
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
927
+            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
927
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
929
+
930
+ Example:
931
+
932
+ ```python
933
+ >>> import torch
934
+ >>> from transformers import AutoTokenizer, AlbertForMaskedLM
935
+
936
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
937
+ >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
938
+
939
+ >>> # add mask_token
940
+ >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
941
+ >>> with torch.no_grad():
942
+ ... logits = model(**inputs).logits
943
+
944
+ >>> # retrieve index of [MASK]
945
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
946
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
947
+ >>> tokenizer.decode(predicted_token_id)
948
+ 'france'
949
+ ```
950
+
951
+ ```python
952
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
953
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
954
+ >>> outputs = model(**inputs, labels=labels)
955
+ >>> round(outputs.loss.item(), 2)
956
+ 0.81
957
+ ```
958
+ """
959
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
960
+
961
+ outputs = self.albert(
962
+ input_ids=input_ids,
963
+ attention_mask=attention_mask,
964
+ token_type_ids=token_type_ids,
965
+ position_ids=position_ids,
966
+ head_mask=head_mask,
967
+ inputs_embeds=inputs_embeds,
968
+ output_attentions=output_attentions,
969
+ output_hidden_states=output_hidden_states,
970
+ return_dict=return_dict,
971
+ )
972
+ sequence_outputs = outputs[0]
973
+
974
+ prediction_scores = self.predictions(sequence_outputs)
975
+
976
+ masked_lm_loss = None
977
+ if labels is not None:
978
+ loss_fct = CrossEntropyLoss()
979
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
980
+
981
+ if not return_dict:
982
+ output = (prediction_scores,) + outputs[2:]
983
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
984
+
985
+ return MaskedLMOutput(
986
+ loss=masked_lm_loss,
987
+ logits=prediction_scores,
988
+ hidden_states=outputs.hidden_states,
989
+ attentions=outputs.attentions,
990
+ )
991
+
992
+
993
+ @auto_docstring(
994
+ custom_intro="""
995
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
996
+ output) e.g. for GLUE tasks.
997
+ """
998
+ )
999
+ class AlbertForSequenceClassification(AlbertPreTrainedModel):
1000
+ def __init__(self, config: AlbertConfig):
1001
+ super().__init__(config)
1002
+ self.num_labels = config.num_labels
1003
+ self.config = config
1004
+
1005
+ self.albert = AlbertModel(config)
1006
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1007
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
1008
+
1009
+ # Initialize weights and apply final processing
1010
+ self.post_init()
1011
+
1012
+ @auto_docstring
1013
+ def forward(
1014
+ self,
1015
+ input_ids: Optional[torch.LongTensor] = None,
1016
+ attention_mask: Optional[torch.FloatTensor] = None,
1017
+ token_type_ids: Optional[torch.LongTensor] = None,
1018
+ position_ids: Optional[torch.LongTensor] = None,
1019
+ head_mask: Optional[torch.FloatTensor] = None,
1020
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1021
+ labels: Optional[torch.LongTensor] = None,
1022
+ output_attentions: Optional[bool] = None,
1023
+ output_hidden_states: Optional[bool] = None,
1024
+ return_dict: Optional[bool] = None,
1025
+ ) -> Union[SequenceClassifierOutput, tuple]:
1026
+ r"""
1027
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1028
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1029
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1030
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1031
+ """
1032
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1033
+
1034
+ outputs = self.albert(
1035
+ input_ids=input_ids,
1036
+ attention_mask=attention_mask,
1037
+ token_type_ids=token_type_ids,
1038
+ position_ids=position_ids,
1039
+ head_mask=head_mask,
1040
+ inputs_embeds=inputs_embeds,
1041
+ output_attentions=output_attentions,
1042
+ output_hidden_states=output_hidden_states,
1043
+ return_dict=return_dict,
1044
+ )
1045
+
1046
+ pooled_output = outputs[1]
1047
+
1048
+ pooled_output = self.dropout(pooled_output)
1049
+ logits = self.classifier(pooled_output)
1050
+
1051
+ loss = None
1052
+ if labels is not None:
1053
+ if self.config.problem_type is None:
1054
+ if self.num_labels == 1:
1055
+ self.config.problem_type = "regression"
1056
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1057
+ self.config.problem_type = "single_label_classification"
1058
+ else:
1059
+ self.config.problem_type = "multi_label_classification"
1060
+
1061
+ if self.config.problem_type == "regression":
1062
+ loss_fct = MSELoss()
1063
+ if self.num_labels == 1:
1064
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1065
+ else:
1066
+ loss = loss_fct(logits, labels)
1067
+ elif self.config.problem_type == "single_label_classification":
1068
+ loss_fct = CrossEntropyLoss()
1069
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1070
+ elif self.config.problem_type == "multi_label_classification":
1071
+ loss_fct = BCEWithLogitsLoss()
1072
+ loss = loss_fct(logits, labels)
1073
+
1074
+ if not return_dict:
1075
+ output = (logits,) + outputs[2:]
1076
+ return ((loss,) + output) if loss is not None else output
1077
+
1078
+ return SequenceClassifierOutput(
1079
+ loss=loss,
1080
+ logits=logits,
1081
+ hidden_states=outputs.hidden_states,
1082
+ attentions=outputs.attentions,
1083
+ )
1084
+
1085
+
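The `problem_type` dispatch in the `forward` above can be summarized in a small standalone sketch; `infer_problem_type` is a hypothetical helper written for illustration, not library API:

```python
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the branching in `forward` above.
    if num_labels == 1:
        return "regression"                       # MSELoss
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"      # CrossEntropyLoss
    return "multi_label_classification"           # BCEWithLogitsLoss

print(infer_problem_type(1, torch.tensor([0.7])))              # regression
print(infer_problem_type(3, torch.tensor([2])))                # single_label_classification
print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification
```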
1086
+ @auto_docstring
1087
+ class AlbertForTokenClassification(AlbertPreTrainedModel):
1088
+ def __init__(self, config: AlbertConfig):
1089
+ super().__init__(config)
1090
+ self.num_labels = config.num_labels
1091
+
1092
+ self.albert = AlbertModel(config, add_pooling_layer=False)
1093
+ classifier_dropout_prob = (
1094
+ config.classifier_dropout_prob
1095
+ if config.classifier_dropout_prob is not None
1096
+ else config.hidden_dropout_prob
1097
+ )
1098
+ self.dropout = nn.Dropout(classifier_dropout_prob)
1099
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
1100
+
1101
+ # Initialize weights and apply final processing
1102
+ self.post_init()
1103
+
1104
+ @auto_docstring
1105
+ def forward(
1106
+ self,
1107
+ input_ids: Optional[torch.LongTensor] = None,
1108
+ attention_mask: Optional[torch.FloatTensor] = None,
1109
+ token_type_ids: Optional[torch.LongTensor] = None,
1110
+ position_ids: Optional[torch.LongTensor] = None,
1111
+ head_mask: Optional[torch.FloatTensor] = None,
1112
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1113
+ labels: Optional[torch.LongTensor] = None,
1114
+ output_attentions: Optional[bool] = None,
1115
+ output_hidden_states: Optional[bool] = None,
1116
+ return_dict: Optional[bool] = None,
1117
+ ) -> Union[TokenClassifierOutput, tuple]:
1118
+ r"""
1119
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1120
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1121
+ """
1122
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1123
+
1124
+ outputs = self.albert(
1125
+ input_ids,
1126
+ attention_mask=attention_mask,
1127
+ token_type_ids=token_type_ids,
1128
+ position_ids=position_ids,
1129
+ head_mask=head_mask,
1130
+ inputs_embeds=inputs_embeds,
1131
+ output_attentions=output_attentions,
1132
+ output_hidden_states=output_hidden_states,
1133
+ return_dict=return_dict,
1134
+ )
1135
+
1136
+ sequence_output = outputs[0]
1137
+
1138
+ sequence_output = self.dropout(sequence_output)
1139
+ logits = self.classifier(sequence_output)
1140
+
1141
+ loss = None
1142
+ if labels is not None:
1143
+ loss_fct = CrossEntropyLoss()
1144
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1145
+
1146
+ if not return_dict:
1147
+ output = (logits,) + outputs[2:]
1148
+ return ((loss,) + output) if loss is not None else output
1149
+
1150
+ return TokenClassifierOutput(
1151
+ loss=loss,
1152
+ logits=logits,
1153
+ hidden_states=outputs.hidden_states,
1154
+ attentions=outputs.attentions,
1155
+ )
1156
+
1157
+
1158
+ @auto_docstring
1159
+ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
1160
+ def __init__(self, config: AlbertConfig):
1161
+ super().__init__(config)
1162
+ self.num_labels = config.num_labels
1163
+
1164
+ self.albert = AlbertModel(config, add_pooling_layer=False)
1165
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1166
+
1167
+ # Initialize weights and apply final processing
1168
+ self.post_init()
1169
+
1170
+ @auto_docstring
1171
+ def forward(
1172
+ self,
1173
+ input_ids: Optional[torch.LongTensor] = None,
1174
+ attention_mask: Optional[torch.FloatTensor] = None,
1175
+ token_type_ids: Optional[torch.LongTensor] = None,
1176
+ position_ids: Optional[torch.LongTensor] = None,
1177
+ head_mask: Optional[torch.FloatTensor] = None,
1178
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1179
+ start_positions: Optional[torch.LongTensor] = None,
1180
+ end_positions: Optional[torch.LongTensor] = None,
1181
+ output_attentions: Optional[bool] = None,
1182
+ output_hidden_states: Optional[bool] = None,
1183
+ return_dict: Optional[bool] = None,
1184
+    ) -> Union[QuestionAnsweringModelOutput, tuple]:
1185
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1186
+
1187
+ outputs = self.albert(
1188
+ input_ids=input_ids,
1189
+ attention_mask=attention_mask,
1190
+ token_type_ids=token_type_ids,
1191
+ position_ids=position_ids,
1192
+ head_mask=head_mask,
1193
+ inputs_embeds=inputs_embeds,
1194
+ output_attentions=output_attentions,
1195
+ output_hidden_states=output_hidden_states,
1196
+ return_dict=return_dict,
1197
+ )
1198
+
1199
+ sequence_output = outputs[0]
1200
+
1201
+ logits: torch.Tensor = self.qa_outputs(sequence_output)
1202
+ start_logits, end_logits = logits.split(1, dim=-1)
1203
+ start_logits = start_logits.squeeze(-1).contiguous()
1204
+ end_logits = end_logits.squeeze(-1).contiguous()
1205
+
1206
+ total_loss = None
1207
+ if start_positions is not None and end_positions is not None:
1208
+            # If we are on multi-GPU, splitting adds an extra dimension
1209
+ if len(start_positions.size()) > 1:
1210
+ start_positions = start_positions.squeeze(-1)
1211
+ if len(end_positions.size()) > 1:
1212
+ end_positions = end_positions.squeeze(-1)
1213
+            # sometimes the start/end positions are outside our model inputs; we ignore these terms
1214
+ ignored_index = start_logits.size(1)
1215
+ start_positions = start_positions.clamp(0, ignored_index)
1216
+ end_positions = end_positions.clamp(0, ignored_index)
1217
+
1218
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1219
+ start_loss = loss_fct(start_logits, start_positions)
1220
+ end_loss = loss_fct(end_logits, end_positions)
1221
+ total_loss = (start_loss + end_loss) / 2
1222
+
1223
+ if not return_dict:
1224
+ output = (start_logits, end_logits) + outputs[2:]
1225
+ return ((total_loss,) + output) if total_loss is not None else output
1226
+
1227
+ return QuestionAnsweringModelOutput(
1228
+ loss=total_loss,
1229
+ start_logits=start_logits,
1230
+ end_logits=end_logits,
1231
+ hidden_states=outputs.hidden_states,
1232
+ attentions=outputs.attentions,
1233
+ )
1234
+
1235
+
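A hedged usage sketch for the span head above: take the argmax of the start/end logits and decode the tokens in between. Note the QA head of the base checkpoint is randomly initialized, so the span is arbitrary until the model is fine-tuned on SQuAD-style data:

```python
import torch
from transformers import AutoTokenizer, AlbertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = AlbertForQuestionAnswering.from_pretrained("albert/albert-base-v2")

question, context = "What is the capital of France?", "Paris is the capital of France."
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))
```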
1236
+ @auto_docstring
1237
+ class AlbertForMultipleChoice(AlbertPreTrainedModel):
1238
+ def __init__(self, config: AlbertConfig):
1239
+ super().__init__(config)
1240
+
1241
+ self.albert = AlbertModel(config)
1242
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1243
+ self.classifier = nn.Linear(config.hidden_size, 1)
1244
+
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ @auto_docstring
1249
+ def forward(
1250
+ self,
1251
+ input_ids: Optional[torch.LongTensor] = None,
1252
+ attention_mask: Optional[torch.FloatTensor] = None,
1253
+ token_type_ids: Optional[torch.LongTensor] = None,
1254
+ position_ids: Optional[torch.LongTensor] = None,
1255
+ head_mask: Optional[torch.FloatTensor] = None,
1256
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1257
+ labels: Optional[torch.LongTensor] = None,
1258
+ output_attentions: Optional[bool] = None,
1259
+ output_hidden_states: Optional[bool] = None,
1260
+ return_dict: Optional[bool] = None,
1261
+    ) -> Union[MultipleChoiceModelOutput, tuple]:
1262
+ r"""
1263
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
1264
+ Indices of input sequence tokens in the vocabulary.
1265
+
1266
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
1267
+ [`PreTrainedTokenizer.encode`] for details.
1268
+
1269
+ [What are input IDs?](../glossary#input-ids)
1270
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
1271
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1272
+ 1]`:
1273
+
1274
+ - 0 corresponds to a *sentence A* token,
1275
+ - 1 corresponds to a *sentence B* token.
1276
+
1277
+ [What are token type IDs?](../glossary#token-type-ids)
1278
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
1279
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1280
+ config.max_position_embeddings - 1]`.
1281
+
1282
+ [What are position IDs?](../glossary#position-ids)
1283
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
1284
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1285
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1286
+ model's internal embedding lookup matrix.
1287
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1288
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1289
+ num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
1290
+ *input_ids* above)
1291
+ """
1292
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1293
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1294
+
1295
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1296
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1297
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1298
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1299
+ inputs_embeds = (
1300
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1301
+ if inputs_embeds is not None
1302
+ else None
1303
+ )
1304
+ outputs = self.albert(
1305
+ input_ids,
1306
+ attention_mask=attention_mask,
1307
+ token_type_ids=token_type_ids,
1308
+ position_ids=position_ids,
1309
+ head_mask=head_mask,
1310
+ inputs_embeds=inputs_embeds,
1311
+ output_attentions=output_attentions,
1312
+ output_hidden_states=output_hidden_states,
1313
+ return_dict=return_dict,
1314
+ )
1315
+
1316
+ pooled_output = outputs[1]
1317
+
1318
+ pooled_output = self.dropout(pooled_output)
1319
+ logits: torch.Tensor = self.classifier(pooled_output)
1320
+ reshaped_logits = logits.view(-1, num_choices)
1321
+
1322
+ loss = None
1323
+ if labels is not None:
1324
+ loss_fct = CrossEntropyLoss()
1325
+ loss = loss_fct(reshaped_logits, labels)
1326
+
1327
+ if not return_dict:
1328
+ output = (reshaped_logits,) + outputs[2:]
1329
+ return ((loss,) + output) if loss is not None else output
1330
+
1331
+ return MultipleChoiceModelOutput(
1332
+ loss=loss,
1333
+ logits=reshaped_logits,
1334
+ hidden_states=outputs.hidden_states,
1335
+ attentions=outputs.attentions,
1336
+ )
1337
+
1338
+
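The head above expects inputs of shape `(batch_size, num_choices, sequence_length)` and flattens the choice dimension before the encoder. A sketch of that layout with two invented choices:

```python
import torch
from transformers import AutoTokenizer, AlbertForMultipleChoice

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = AlbertForMultipleChoice.from_pretrained("albert/albert-base-v2")

prompt = "The man poured the coffee"
choices = ["into the cup.", "into the cloud."]
encoded = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)

# Add the num_choices dimension: (2, seq_len) -> (1, 2, seq_len).
inputs = {k: v.unsqueeze(0) for k, v in encoded.items()}
outputs = model(**inputs, labels=torch.tensor([0]))
print(outputs.logits.shape)  # torch.Size([1, 2])
```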
1339
+ __all__ = [
1340
+ "load_tf_weights_in_albert",
1341
+ "AlbertPreTrainedModel",
1342
+ "AlbertModel",
1343
+ "AlbertForPreTraining",
1344
+ "AlbertForMaskedLM",
1345
+ "AlbertForSequenceClassification",
1346
+ "AlbertForTokenClassification",
1347
+ "AlbertForQuestionAnswering",
1348
+ "AlbertForMultipleChoice",
1349
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py ADDED
@@ -0,0 +1,1132 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Callable, Optional
17
+
18
+ import flax
19
+ import flax.linen as nn
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
+ from flax.linen.attention import dot_product_attention_weights
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from jax import lax
27
+
28
+ from ...modeling_flax_outputs import (
29
+ FlaxBaseModelOutput,
30
+ FlaxBaseModelOutputWithPooling,
31
+ FlaxMaskedLMOutput,
32
+ FlaxMultipleChoiceModelOutput,
33
+ FlaxQuestionAnsweringModelOutput,
34
+ FlaxSequenceClassifierOutput,
35
+ FlaxTokenClassifierOutput,
36
+ )
37
+ from ...modeling_flax_utils import (
38
+ ACT2FN,
39
+ FlaxPreTrainedModel,
40
+ append_call_sample_docstring,
41
+ append_replace_return_docstrings,
42
+ overwrite_call_docstring,
43
+ )
44
+ from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
45
+ from .configuration_albert import AlbertConfig
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
51
+ _CONFIG_FOR_DOC = "AlbertConfig"
52
+
53
+
54
+ @flax.struct.dataclass
55
+ class FlaxAlbertForPreTrainingOutput(ModelOutput):
56
+ """
57
+ Output type of [`FlaxAlbertForPreTraining`].
58
+
59
+ Args:
60
+ prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
61
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
62
+ sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
63
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
64
+ before SoftMax).
65
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
67
+ `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
70
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ """
77
+
78
+ prediction_logits: jnp.ndarray = None
79
+ sop_logits: jnp.ndarray = None
80
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
81
+ attentions: Optional[tuple[jnp.ndarray]] = None
82
+
83
+
84
+ ALBERT_START_DOCSTRING = r"""
85
+
86
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
87
+    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
88
+
89
+ This model is also a
90
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
91
+    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
92
+ behavior.
93
+
94
+    Finally, this model supports built-in JAX features such as:
95
+
96
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
97
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
98
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
99
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
100
+
101
+ Parameters:
102
+ config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
103
+ Initializing with a config file does not load the weights associated with the model, only the
104
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
105
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
106
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
107
+ `jax.numpy.bfloat16` (on TPUs).
108
+
109
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
110
+ specified all the computation will be performed with the given `dtype`.
111
+
112
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
113
+ parameters.**
114
+
115
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
116
+ [`~FlaxPreTrainedModel.to_bf16`].
117
+ """
118
+
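A minimal sketch of the JIT support listed above, wrapping a forward pass in `jax.jit`; the model object is closed over, so only the input arrays are traced:

```python
import jax
from transformers import AutoTokenizer, FlaxAlbertModel

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = FlaxAlbertModel.from_pretrained("albert/albert-base-v2")
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")

@jax.jit
def encode(input_ids, attention_mask):
    return model(input_ids, attention_mask=attention_mask).last_hidden_state

hidden = encode(inputs["input_ids"], inputs["attention_mask"])
print(hidden.shape)  # (1, sequence_length, 768)
```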
119
+ ALBERT_INPUTS_DOCSTRING = r"""
120
+ Args:
121
+ input_ids (`numpy.ndarray` of shape `({0})`):
122
+ Indices of input sequence tokens in the vocabulary.
123
+
124
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
125
+ [`PreTrainedTokenizer.__call__`] for details.
126
+
127
+ [What are input IDs?](../glossary#input-ids)
128
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
129
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
130
+
131
+ - 1 for tokens that are **not masked**,
132
+ - 0 for tokens that are **masked**.
133
+
134
+ [What are attention masks?](../glossary#attention-mask)
135
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
136
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
137
+ 1]`:
138
+
139
+ - 0 corresponds to a *sentence A* token,
140
+ - 1 corresponds to a *sentence B* token.
141
+
142
+ [What are token type IDs?](../glossary#token-type-ids)
143
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
144
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
145
+ config.max_position_embeddings - 1]`.
146
+ return_dict (`bool`, *optional*):
147
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
148
+
149
+ """
150
+
151
+
152
+ class FlaxAlbertEmbeddings(nn.Module):
153
+ """Construct the embeddings from word, position and token_type embeddings."""
154
+
155
+ config: AlbertConfig
156
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
157
+
158
+ def setup(self):
159
+ self.word_embeddings = nn.Embed(
160
+ self.config.vocab_size,
161
+ self.config.embedding_size,
162
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
163
+ )
164
+ self.position_embeddings = nn.Embed(
165
+ self.config.max_position_embeddings,
166
+ self.config.embedding_size,
167
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
168
+ )
169
+ self.token_type_embeddings = nn.Embed(
170
+ self.config.type_vocab_size,
171
+ self.config.embedding_size,
172
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
173
+ )
174
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
175
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
176
+
177
+ def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
178
+ # Embed
179
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
180
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
181
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
182
+
183
+ # Sum all embeddings
184
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
185
+
186
+ # Layer Norm
187
+ hidden_states = self.LayerNorm(hidden_states)
188
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
189
+ return hidden_states
190
+
191
+
192
+ class FlaxAlbertSelfAttention(nn.Module):
193
+ config: AlbertConfig
194
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
195
+
196
+ def setup(self):
197
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
198
+ raise ValueError(
199
+                 f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
200
+                 f"`config.num_attention_heads`: {self.config.num_attention_heads}"
201
+ )
202
+
203
+ self.query = nn.Dense(
204
+ self.config.hidden_size,
205
+ dtype=self.dtype,
206
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
207
+ )
208
+ self.key = nn.Dense(
209
+ self.config.hidden_size,
210
+ dtype=self.dtype,
211
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
212
+ )
213
+ self.value = nn.Dense(
214
+ self.config.hidden_size,
215
+ dtype=self.dtype,
216
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
217
+ )
218
+ self.dense = nn.Dense(
219
+ self.config.hidden_size,
220
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
221
+ dtype=self.dtype,
222
+ )
223
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
224
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
225
+
226
+ def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
227
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
228
+
229
+ query_states = self.query(hidden_states).reshape(
230
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
231
+ )
232
+ value_states = self.value(hidden_states).reshape(
233
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
234
+ )
235
+ key_states = self.key(hidden_states).reshape(
236
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
237
+ )
238
+
239
+ # Convert the boolean attention mask to an attention bias.
240
+ if attention_mask is not None:
241
+ # attention mask in the form of attention bias
242
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
243
+ attention_bias = lax.select(
244
+ attention_mask > 0,
245
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
246
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
247
+ )
248
+ else:
249
+ attention_bias = None
250
+
251
+ dropout_rng = None
252
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
253
+ dropout_rng = self.make_rng("dropout")
254
+
255
+ attn_weights = dot_product_attention_weights(
256
+ query_states,
257
+ key_states,
258
+ bias=attention_bias,
259
+ dropout_rng=dropout_rng,
260
+ dropout_rate=self.config.attention_probs_dropout_prob,
261
+ broadcast_dropout=True,
262
+ deterministic=deterministic,
263
+ dtype=self.dtype,
264
+ precision=None,
265
+ )
266
+
267
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
268
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
269
+
270
+ projected_attn_output = self.dense(attn_output)
271
+ projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
272
+ layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
273
+ outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
274
+ return outputs
275
+
276
+
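The mask-to-bias conversion above replaces masked positions with a large negative number so they vanish under the attention softmax. A standalone sketch of the same `lax.select` pattern:

```python
import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1, 1, 0]])                # 1 = attend, 0 = padding
mask = jnp.expand_dims(mask, axis=(-3, -2))  # broadcast over heads and query positions
bias = lax.select(
    mask > 0,
    jnp.full(mask.shape, 0.0),
    jnp.full(mask.shape, jnp.finfo(jnp.float32).min),
)
print(bias)  # 0.0 where attended, ~-3.4e38 where padded
```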
277
+ class FlaxAlbertLayer(nn.Module):
278
+ config: AlbertConfig
279
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
280
+
281
+ def setup(self):
282
+ self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
283
+ self.ffn = nn.Dense(
284
+ self.config.intermediate_size,
285
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
286
+ dtype=self.dtype,
287
+ )
288
+ self.activation = ACT2FN[self.config.hidden_act]
289
+ self.ffn_output = nn.Dense(
290
+ self.config.hidden_size,
291
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
292
+ dtype=self.dtype,
293
+ )
294
+ self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
295
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
296
+
297
+ def __call__(
298
+ self,
299
+ hidden_states,
300
+ attention_mask,
301
+ deterministic: bool = True,
302
+ output_attentions: bool = False,
303
+ ):
304
+ attention_outputs = self.attention(
305
+ hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
306
+ )
307
+ attention_output = attention_outputs[0]
308
+ ffn_output = self.ffn(attention_output)
309
+ ffn_output = self.activation(ffn_output)
310
+ ffn_output = self.ffn_output(ffn_output)
311
+ ffn_output = self.dropout(ffn_output, deterministic=deterministic)
312
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
313
+
314
+ outputs = (hidden_states,)
315
+
316
+ if output_attentions:
317
+ outputs += (attention_outputs[1],)
318
+ return outputs
319
+
320
+
321
+ class FlaxAlbertLayerCollection(nn.Module):
322
+ config: AlbertConfig
323
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
324
+
325
+ def setup(self):
326
+ self.layers = [
327
+ FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
328
+ ]
329
+
330
+ def __call__(
331
+ self,
332
+ hidden_states,
333
+ attention_mask,
334
+ deterministic: bool = True,
335
+ output_attentions: bool = False,
336
+ output_hidden_states: bool = False,
337
+ ):
338
+ layer_hidden_states = ()
339
+ layer_attentions = ()
340
+
341
+ for layer_index, albert_layer in enumerate(self.layers):
342
+ layer_output = albert_layer(
343
+ hidden_states,
344
+ attention_mask,
345
+ deterministic=deterministic,
346
+ output_attentions=output_attentions,
347
+ )
348
+ hidden_states = layer_output[0]
349
+
350
+ if output_attentions:
351
+ layer_attentions = layer_attentions + (layer_output[1],)
352
+
353
+ if output_hidden_states:
354
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
355
+
356
+ outputs = (hidden_states,)
357
+ if output_hidden_states:
358
+ outputs = outputs + (layer_hidden_states,)
359
+ if output_attentions:
360
+ outputs = outputs + (layer_attentions,)
361
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
362
+
363
+
364
+ class FlaxAlbertLayerCollections(nn.Module):
365
+ config: AlbertConfig
366
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
367
+ layer_index: Optional[str] = None
368
+
369
+ def setup(self):
370
+ self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)
371
+
372
+ def __call__(
373
+ self,
374
+ hidden_states,
375
+ attention_mask,
376
+ deterministic: bool = True,
377
+ output_attentions: bool = False,
378
+ output_hidden_states: bool = False,
379
+ ):
380
+ outputs = self.albert_layers(
381
+ hidden_states,
382
+ attention_mask,
383
+ deterministic=deterministic,
384
+ output_attentions=output_attentions,
385
+ output_hidden_states=output_hidden_states,
386
+ )
387
+ return outputs
388
+
389
+
390
+ class FlaxAlbertLayerGroups(nn.Module):
391
+ config: AlbertConfig
392
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
393
+
394
+ def setup(self):
395
+ self.layers = [
396
+ FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
397
+ for i in range(self.config.num_hidden_groups)
398
+ ]
399
+
400
+ def __call__(
401
+ self,
402
+ hidden_states,
403
+ attention_mask,
404
+ deterministic: bool = True,
405
+ output_attentions: bool = False,
406
+ output_hidden_states: bool = False,
407
+ return_dict: bool = True,
408
+ ):
409
+ all_attentions = () if output_attentions else None
410
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
411
+
412
+ for i in range(self.config.num_hidden_layers):
413
+ # Index of the hidden group
414
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
415
+ layer_group_output = self.layers[group_idx](
416
+ hidden_states,
417
+ attention_mask,
418
+ deterministic=deterministic,
419
+ output_attentions=output_attentions,
420
+ output_hidden_states=output_hidden_states,
421
+ )
422
+ hidden_states = layer_group_output[0]
423
+
424
+ if output_attentions:
425
+ all_attentions = all_attentions + layer_group_output[-1]
426
+
427
+ if output_hidden_states:
428
+ all_hidden_states = all_hidden_states + (hidden_states,)
429
+
430
+ if not return_dict:
431
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
432
+ return FlaxBaseModelOutput(
433
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
434
+ )
435
+
436
+
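The `group_idx` computation above implements ALBERT's cross-layer parameter sharing: each of the `num_hidden_layers` forward steps is routed to one of `num_hidden_groups` parameter groups. A sketch with illustrative configurations:

```python
def layer_to_group(num_hidden_layers: int, num_hidden_groups: int) -> list[int]:
    # Same formula as above: evenly spread the layers over the groups.
    return [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(num_hidden_layers)]

print(layer_to_group(12, 1))  # [0]*12 -> ALBERT default: one fully shared group
print(layer_to_group(12, 4))  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```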
437
+ class FlaxAlbertEncoder(nn.Module):
438
+ config: AlbertConfig
439
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
440
+
441
+ def setup(self):
442
+ self.embedding_hidden_mapping_in = nn.Dense(
443
+ self.config.hidden_size,
444
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
445
+ dtype=self.dtype,
446
+ )
447
+ self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)
448
+
449
+ def __call__(
450
+ self,
451
+ hidden_states,
452
+ attention_mask,
453
+ deterministic: bool = True,
454
+ output_attentions: bool = False,
455
+ output_hidden_states: bool = False,
456
+ return_dict: bool = True,
457
+ ):
458
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
459
+ return self.albert_layer_groups(
460
+ hidden_states,
461
+ attention_mask,
462
+ deterministic=deterministic,
463
+ output_attentions=output_attentions,
464
+             output_hidden_states=output_hidden_states,
465
+             return_dict=return_dict,
+         )
466
+
467
+
468
+ class FlaxAlbertOnlyMLMHead(nn.Module):
469
+ config: AlbertConfig
470
+ dtype: jnp.dtype = jnp.float32
471
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
472
+
473
+ def setup(self):
474
+ self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
475
+ self.activation = ACT2FN[self.config.hidden_act]
476
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
477
+ self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
478
+ self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
479
+
480
+ def __call__(self, hidden_states, shared_embedding=None):
481
+ hidden_states = self.dense(hidden_states)
482
+ hidden_states = self.activation(hidden_states)
483
+ hidden_states = self.LayerNorm(hidden_states)
484
+
485
+ if shared_embedding is not None:
486
+ hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
487
+ else:
488
+ hidden_states = self.decoder(hidden_states)
489
+
490
+ hidden_states += self.bias
491
+ return hidden_states
492
+
493
+
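When `tie_word_embeddings` is set, the decoder above is applied with the transposed word-embedding matrix as its kernel rather than its own weights. A shape-only sketch of the equivalent matmul, with illustrative sizes:

```python
import jax.numpy as jnp

embedding = jnp.ones((30000, 128))  # (vocab_size, embedding_size), the tied table
hidden = jnp.ones((1, 5, 128))      # output of the dense/activation/LayerNorm stack
logits = hidden @ embedding.T       # what decoder.apply(..., kernel=embedding.T) computes
print(logits.shape)                 # (1, 5, 30000)
```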
494
+ class FlaxAlbertSOPHead(nn.Module):
495
+ config: AlbertConfig
496
+ dtype: jnp.dtype = jnp.float32
497
+
498
+ def setup(self):
499
+ self.dropout = nn.Dropout(self.config.classifier_dropout_prob)
500
+ self.classifier = nn.Dense(2, dtype=self.dtype)
501
+
502
+ def __call__(self, pooled_output, deterministic=True):
503
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
504
+ logits = self.classifier(pooled_output)
505
+ return logits
506
+
507
+
508
+ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
509
+ """
510
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
511
+ models.
512
+ """
513
+
514
+ config_class = AlbertConfig
515
+ base_model_prefix = "albert"
516
+ module_class: nn.Module = None
517
+
518
+ def __init__(
519
+ self,
520
+ config: AlbertConfig,
521
+ input_shape: tuple = (1, 1),
522
+ seed: int = 0,
523
+ dtype: jnp.dtype = jnp.float32,
524
+ _do_init: bool = True,
525
+ **kwargs,
526
+ ):
527
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
528
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
529
+
530
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
531
+ # init input tensors
532
+ input_ids = jnp.zeros(input_shape, dtype="i4")
533
+ token_type_ids = jnp.zeros_like(input_ids)
534
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
535
+ attention_mask = jnp.ones_like(input_ids)
536
+
537
+ params_rng, dropout_rng = jax.random.split(rng)
538
+ rngs = {"params": params_rng, "dropout": dropout_rng}
539
+
540
+ random_params = self.module.init(
541
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
542
+ )["params"]
543
+
544
+ if params is not None:
545
+ random_params = flatten_dict(unfreeze(random_params))
546
+ params = flatten_dict(unfreeze(params))
547
+ for missing_key in self._missing_keys:
548
+ params[missing_key] = random_params[missing_key]
549
+ self._missing_keys = set()
550
+ return freeze(unflatten_dict(params))
551
+ else:
552
+ return random_params
553
+
554
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
555
+ def __call__(
556
+ self,
557
+ input_ids,
558
+ attention_mask=None,
559
+ token_type_ids=None,
560
+ position_ids=None,
561
+ params: Optional[dict] = None,
562
+ dropout_rng: jax.random.PRNGKey = None,
563
+ train: bool = False,
564
+ output_attentions: Optional[bool] = None,
565
+ output_hidden_states: Optional[bool] = None,
566
+ return_dict: Optional[bool] = None,
567
+ ):
568
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
569
+ output_hidden_states = (
570
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
571
+ )
572
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
573
+
574
+ # init input tensors if not passed
575
+ if token_type_ids is None:
576
+ token_type_ids = jnp.zeros_like(input_ids)
577
+
578
+ if position_ids is None:
579
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
580
+
581
+ if attention_mask is None:
582
+ attention_mask = jnp.ones_like(input_ids)
583
+
584
+ # Handle any PRNG if needed
585
+ rngs = {}
586
+ if dropout_rng is not None:
587
+ rngs["dropout"] = dropout_rng
588
+
589
+ return self.module.apply(
590
+ {"params": params or self.params},
591
+ jnp.array(input_ids, dtype="i4"),
592
+ jnp.array(attention_mask, dtype="i4"),
593
+ jnp.array(token_type_ids, dtype="i4"),
594
+ jnp.array(position_ids, dtype="i4"),
595
+ not train,
596
+ output_attentions,
597
+ output_hidden_states,
598
+ return_dict,
599
+ rngs=rngs,
600
+ )
601
+
602
+
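A small sketch of the PRNG plumbing used in `init_weights` and `__call__` above: one root key is split into independent streams for parameter initialization and dropout:

```python
import jax

root = jax.random.PRNGKey(0)
params_rng, dropout_rng = jax.random.split(root)
rngs = {"params": params_rng, "dropout": dropout_rng}  # passed to module.init / module.apply
print(sorted(rngs))  # ['dropout', 'params']
```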
603
+ class FlaxAlbertModule(nn.Module):
604
+ config: AlbertConfig
605
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
606
+ add_pooling_layer: bool = True
607
+
608
+ def setup(self):
609
+ self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
610
+ self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)
611
+ if self.add_pooling_layer:
612
+ self.pooler = nn.Dense(
613
+ self.config.hidden_size,
614
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
615
+ dtype=self.dtype,
616
+ name="pooler",
617
+ )
618
+ self.pooler_activation = nn.tanh
619
+ else:
620
+ self.pooler = None
621
+ self.pooler_activation = None
622
+
623
+ def __call__(
624
+ self,
625
+ input_ids,
626
+ attention_mask,
627
+ token_type_ids: Optional[np.ndarray] = None,
628
+ position_ids: Optional[np.ndarray] = None,
629
+ deterministic: bool = True,
630
+ output_attentions: bool = False,
631
+ output_hidden_states: bool = False,
632
+ return_dict: bool = True,
633
+ ):
634
+ # make sure `token_type_ids` is correctly initialized when not passed
635
+ if token_type_ids is None:
636
+ token_type_ids = jnp.zeros_like(input_ids)
637
+
638
+ # make sure `position_ids` is correctly initialized when not passed
639
+ if position_ids is None:
640
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
641
+
642
+ hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
643
+
644
+ outputs = self.encoder(
645
+ hidden_states,
646
+ attention_mask,
647
+ deterministic=deterministic,
648
+ output_attentions=output_attentions,
649
+ output_hidden_states=output_hidden_states,
650
+ return_dict=return_dict,
651
+ )
652
+ hidden_states = outputs[0]
653
+ if self.add_pooling_layer:
654
+ pooled = self.pooler(hidden_states[:, 0])
655
+ pooled = self.pooler_activation(pooled)
656
+ else:
657
+ pooled = None
658
+
659
+ if not return_dict:
660
+ # if pooled is None, don't return it
661
+ if pooled is None:
662
+ return (hidden_states,) + outputs[1:]
663
+ return (hidden_states, pooled) + outputs[1:]
664
+
665
+ return FlaxBaseModelOutputWithPooling(
666
+ last_hidden_state=hidden_states,
667
+ pooler_output=pooled,
668
+ hidden_states=outputs.hidden_states,
669
+ attentions=outputs.attentions,
670
+ )
671
+
672
+
673
+ @add_start_docstrings(
674
+ "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
675
+ ALBERT_START_DOCSTRING,
676
+ )
677
+ class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
678
+ module_class = FlaxAlbertModule
679
+
680
+
681
+ append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
682
+
683
+
684
+ class FlaxAlbertForPreTrainingModule(nn.Module):
685
+ config: AlbertConfig
686
+ dtype: jnp.dtype = jnp.float32
687
+
688
+ def setup(self):
689
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
690
+ self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
691
+ self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)
692
+
693
+ def __call__(
694
+ self,
695
+ input_ids,
696
+ attention_mask,
697
+ token_type_ids,
698
+ position_ids,
699
+ deterministic: bool = True,
700
+ output_attentions: bool = False,
701
+ output_hidden_states: bool = False,
702
+ return_dict: bool = True,
703
+ ):
704
+ # Model
705
+ outputs = self.albert(
706
+ input_ids,
707
+ attention_mask,
708
+ token_type_ids,
709
+ position_ids,
710
+ deterministic=deterministic,
711
+ output_attentions=output_attentions,
712
+ output_hidden_states=output_hidden_states,
713
+ return_dict=return_dict,
714
+ )
715
+
716
+ if self.config.tie_word_embeddings:
717
+ shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
718
+ else:
719
+ shared_embedding = None
720
+
721
+ hidden_states = outputs[0]
722
+ pooled_output = outputs[1]
723
+
724
+ prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
725
+ sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic)
726
+
727
+ if not return_dict:
728
+ return (prediction_scores, sop_scores) + outputs[2:]
729
+
730
+ return FlaxAlbertForPreTrainingOutput(
731
+ prediction_logits=prediction_scores,
732
+ sop_logits=sop_scores,
733
+ hidden_states=outputs.hidden_states,
734
+ attentions=outputs.attentions,
735
+ )
736
+
737
+
738
+ @add_start_docstrings(
739
+ """
740
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
741
+ `sentence order prediction (classification)` head.
742
+ """,
743
+ ALBERT_START_DOCSTRING,
744
+ )
745
+ class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
746
+ module_class = FlaxAlbertForPreTrainingModule
747
+
748
+
749
+ FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
750
+ Returns:
751
+
752
+ Example:
753
+
754
+ ```python
755
+ >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining
756
+
757
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
758
+ >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")
759
+
760
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
761
+ >>> outputs = model(**inputs)
762
+
763
+ >>> prediction_logits = outputs.prediction_logits
764
+ >>> seq_relationship_logits = outputs.sop_logits
765
+ ```
766
+ """
767
+
768
+ overwrite_call_docstring(
769
+ FlaxAlbertForPreTraining,
770
+ ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
771
+ )
772
+ append_replace_return_docstrings(
773
+ FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
774
+ )
775
+
776
+
777
+ class FlaxAlbertForMaskedLMModule(nn.Module):
778
+ config: AlbertConfig
779
+ dtype: jnp.dtype = jnp.float32
780
+
781
+ def setup(self):
782
+ self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
783
+ self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
784
+
785
+ def __call__(
786
+ self,
787
+ input_ids,
788
+ attention_mask,
789
+ token_type_ids,
790
+ position_ids,
791
+ deterministic: bool = True,
792
+ output_attentions: bool = False,
793
+ output_hidden_states: bool = False,
794
+ return_dict: bool = True,
795
+ ):
796
+ # Model
797
+ outputs = self.albert(
798
+ input_ids,
799
+ attention_mask,
800
+ token_type_ids,
801
+ position_ids,
802
+ deterministic=deterministic,
803
+ output_attentions=output_attentions,
804
+ output_hidden_states=output_hidden_states,
805
+ return_dict=return_dict,
806
+ )
807
+
808
+ hidden_states = outputs[0]
809
+ if self.config.tie_word_embeddings:
810
+ shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
811
+ else:
812
+ shared_embedding = None
813
+
814
+ # Compute the prediction scores
815
+ logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
816
+
817
+ if not return_dict:
818
+ return (logits,) + outputs[1:]
819
+
820
+ return FlaxMaskedLMOutput(
821
+ logits=logits,
822
+ hidden_states=outputs.hidden_states,
823
+ attentions=outputs.attentions,
824
+ )
825
+
826
+
827
+ @add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
828
+ class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
829
+ module_class = FlaxAlbertForMaskedLMModule
830
+
831
+
832
+ append_call_sample_docstring(
833
+ FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
834
+ )
835
+
836
+
837
+ class FlaxAlbertForSequenceClassificationModule(nn.Module):
838
+ config: AlbertConfig
839
+ dtype: jnp.dtype = jnp.float32
840
+
841
+ def setup(self):
842
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
843
+ classifier_dropout = (
844
+ self.config.classifier_dropout_prob
845
+ if self.config.classifier_dropout_prob is not None
846
+ else self.config.hidden_dropout_prob
847
+ )
848
+ self.dropout = nn.Dropout(rate=classifier_dropout)
849
+ self.classifier = nn.Dense(
850
+ self.config.num_labels,
851
+ dtype=self.dtype,
852
+ )
853
+
854
+ def __call__(
855
+ self,
856
+ input_ids,
857
+ attention_mask,
858
+ token_type_ids,
859
+ position_ids,
860
+ deterministic: bool = True,
861
+ output_attentions: bool = False,
862
+ output_hidden_states: bool = False,
863
+ return_dict: bool = True,
864
+ ):
865
+ # Model
866
+ outputs = self.albert(
867
+ input_ids,
868
+ attention_mask,
869
+ token_type_ids,
870
+ position_ids,
871
+ deterministic=deterministic,
872
+ output_attentions=output_attentions,
873
+ output_hidden_states=output_hidden_states,
874
+ return_dict=return_dict,
875
+ )
876
+
877
+ pooled_output = outputs[1]
878
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
879
+ logits = self.classifier(pooled_output)
880
+
881
+ if not return_dict:
882
+ return (logits,) + outputs[2:]
883
+
884
+ return FlaxSequenceClassifierOutput(
885
+ logits=logits,
886
+ hidden_states=outputs.hidden_states,
887
+ attentions=outputs.attentions,
888
+ )
889
+
890
+
891
+ @add_start_docstrings(
892
+ """
893
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
894
+ output) e.g. for GLUE tasks.
895
+ """,
896
+ ALBERT_START_DOCSTRING,
897
+ )
898
+ class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
899
+ module_class = FlaxAlbertForSequenceClassificationModule
900
+
901
+
902
+ append_call_sample_docstring(
903
+ FlaxAlbertForSequenceClassification,
904
+ _CHECKPOINT_FOR_DOC,
905
+ FlaxSequenceClassifierOutput,
906
+ _CONFIG_FOR_DOC,
907
+ )
908
+
909
+
910
+ class FlaxAlbertForMultipleChoiceModule(nn.Module):
911
+ config: AlbertConfig
912
+ dtype: jnp.dtype = jnp.float32
913
+
914
+ def setup(self):
915
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
916
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
917
+ self.classifier = nn.Dense(1, dtype=self.dtype)
918
+
919
+ def __call__(
920
+ self,
921
+ input_ids,
922
+ attention_mask,
923
+ token_type_ids,
924
+ position_ids,
925
+ deterministic: bool = True,
926
+ output_attentions: bool = False,
927
+ output_hidden_states: bool = False,
928
+ return_dict: bool = True,
929
+ ):
930
+ num_choices = input_ids.shape[1]
931
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
932
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
933
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
934
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
935
+
936
+ # Model
937
+ outputs = self.albert(
938
+ input_ids,
939
+ attention_mask,
940
+ token_type_ids,
941
+ position_ids,
942
+ deterministic=deterministic,
943
+ output_attentions=output_attentions,
944
+ output_hidden_states=output_hidden_states,
945
+ return_dict=return_dict,
946
+ )
947
+
948
+ pooled_output = outputs[1]
949
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
950
+ logits = self.classifier(pooled_output)
951
+
952
+ reshaped_logits = logits.reshape(-1, num_choices)
953
+
954
+ if not return_dict:
955
+ return (reshaped_logits,) + outputs[2:]
956
+
957
+ return FlaxMultipleChoiceModelOutput(
958
+ logits=reshaped_logits,
959
+ hidden_states=outputs.hidden_states,
960
+ attentions=outputs.attentions,
961
+ )
962
+
963
+
964
+ @add_start_docstrings(
965
+ """
966
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
967
+ softmax) e.g. for RocStories/SWAG tasks.
968
+ """,
969
+ ALBERT_START_DOCSTRING,
970
+ )
971
+ class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
972
+ module_class = FlaxAlbertForMultipleChoiceModule
973
+
974
+
975
+ overwrite_call_docstring(
976
+ FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
977
+ )
978
+ append_call_sample_docstring(
979
+ FlaxAlbertForMultipleChoice,
980
+ _CHECKPOINT_FOR_DOC,
981
+ FlaxMultipleChoiceModelOutput,
982
+ _CONFIG_FOR_DOC,
983
+ )
984
+
985
+
986
+ class FlaxAlbertForTokenClassificationModule(nn.Module):
987
+ config: AlbertConfig
988
+ dtype: jnp.dtype = jnp.float32
989
+
990
+ def setup(self):
991
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
992
+ classifier_dropout = (
993
+ self.config.classifier_dropout_prob
994
+ if self.config.classifier_dropout_prob is not None
995
+ else self.config.hidden_dropout_prob
996
+ )
997
+ self.dropout = nn.Dropout(rate=classifier_dropout)
998
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
999
+
1000
+ def __call__(
1001
+ self,
1002
+ input_ids,
1003
+ attention_mask,
1004
+ token_type_ids,
1005
+ position_ids,
1006
+ deterministic: bool = True,
1007
+ output_attentions: bool = False,
1008
+ output_hidden_states: bool = False,
1009
+ return_dict: bool = True,
1010
+ ):
1011
+ # Model
1012
+ outputs = self.albert(
1013
+ input_ids,
1014
+ attention_mask,
1015
+ token_type_ids,
1016
+ position_ids,
1017
+ deterministic=deterministic,
1018
+ output_attentions=output_attentions,
1019
+ output_hidden_states=output_hidden_states,
1020
+ return_dict=return_dict,
1021
+ )
1022
+
1023
+ hidden_states = outputs[0]
1024
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
1025
+ logits = self.classifier(hidden_states)
1026
+
1027
+ if not return_dict:
1028
+ return (logits,) + outputs[1:]
1029
+
1030
+ return FlaxTokenClassifierOutput(
1031
+ logits=logits,
1032
+ hidden_states=outputs.hidden_states,
1033
+ attentions=outputs.attentions,
1034
+ )
1035
+
1036
+
1037
+ @add_start_docstrings(
1038
+ """
1039
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1040
+ Named-Entity-Recognition (NER) tasks.
1041
+ """,
1042
+ ALBERT_START_DOCSTRING,
1043
+ )
1044
+ class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
1045
+ module_class = FlaxAlbertForTokenClassificationModule
1046
+
1047
+
1048
+ append_call_sample_docstring(
1049
+ FlaxAlbertForTokenClassification,
1050
+ _CHECKPOINT_FOR_DOC,
1051
+ FlaxTokenClassifierOutput,
1052
+ _CONFIG_FOR_DOC,
1053
+ )
1054
+
1055
+
1056
+ class FlaxAlbertForQuestionAnsweringModule(nn.Module):
1057
+ config: AlbertConfig
1058
+ dtype: jnp.dtype = jnp.float32
1059
+
1060
+ def setup(self):
1061
+ self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
1062
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
1063
+
1064
+ def __call__(
1065
+ self,
1066
+ input_ids,
1067
+ attention_mask,
1068
+ token_type_ids,
1069
+ position_ids,
1070
+ deterministic: bool = True,
1071
+ output_attentions: bool = False,
1072
+ output_hidden_states: bool = False,
1073
+ return_dict: bool = True,
1074
+ ):
1075
+ # Model
1076
+ outputs = self.albert(
1077
+ input_ids,
1078
+ attention_mask,
1079
+ token_type_ids,
1080
+ position_ids,
1081
+ deterministic=deterministic,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ return_dict=return_dict,
1085
+ )
1086
+
1087
+ hidden_states = outputs[0]
1088
+
1089
+ logits = self.qa_outputs(hidden_states)
1090
+ start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
1091
+ start_logits = start_logits.squeeze(-1)
1092
+ end_logits = end_logits.squeeze(-1)
1093
+
1094
+ if not return_dict:
1095
+ return (start_logits, end_logits) + outputs[1:]
1096
+
1097
+ return FlaxQuestionAnsweringModelOutput(
1098
+ start_logits=start_logits,
1099
+ end_logits=end_logits,
1100
+ hidden_states=outputs.hidden_states,
1101
+ attentions=outputs.attentions,
1102
+ )
1103
+
1104
+
1105
+ @add_start_docstrings(
1106
+ """
1107
+ Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1108
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1109
+ """,
1110
+ ALBERT_START_DOCSTRING,
1111
+ )
1112
+ class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
1113
+ module_class = FlaxAlbertForQuestionAnsweringModule
1114
+
1115
+
1116
+ append_call_sample_docstring(
1117
+ FlaxAlbertForQuestionAnswering,
1118
+ _CHECKPOINT_FOR_DOC,
1119
+ FlaxQuestionAnsweringModelOutput,
1120
+ _CONFIG_FOR_DOC,
1121
+ )
1122
+
1123
+ __all__ = [
1124
+ "FlaxAlbertPreTrainedModel",
1125
+ "FlaxAlbertModel",
1126
+ "FlaxAlbertForPreTraining",
1127
+ "FlaxAlbertForMaskedLM",
1128
+ "FlaxAlbertForSequenceClassification",
1129
+ "FlaxAlbertForMultipleChoice",
1130
+ "FlaxAlbertForTokenClassification",
1131
+ "FlaxAlbertForQuestionAnswering",
1132
+ ]
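A minimal usage sketch for the exported Flax classes (the checkpoint name mirrors the one used in the documented examples elsewhere in the library; fetching it requires network access):

```python
from transformers import AutoTokenizer, FlaxAlbertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = FlaxAlbertForSequenceClassification.from_pretrained("albert/albert-base-v2")

inputs = tokenizer("A short test sentence.", return_tensors="jax")
outputs = model(**inputs)
print(outputs.logits.shape)  # (1, num_labels)
```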
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py ADDED
@@ -0,0 +1,1572 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """TF 2.0 ALBERT model."""
17
+
18
+ from __future__ import annotations
19
+
20
+ import math
21
+ from dataclasses import dataclass
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ...activations_tf import get_tf_activation
27
+ from ...modeling_tf_outputs import (
28
+ TFBaseModelOutput,
29
+ TFBaseModelOutputWithPooling,
30
+ TFMaskedLMOutput,
31
+ TFMultipleChoiceModelOutput,
32
+ TFQuestionAnsweringModelOutput,
33
+ TFSequenceClassifierOutput,
34
+ TFTokenClassifierOutput,
35
+ )
36
+ from ...modeling_tf_utils import (
37
+ TFMaskedLanguageModelingLoss,
38
+ TFModelInputType,
39
+ TFMultipleChoiceLoss,
40
+ TFPreTrainedModel,
41
+ TFQuestionAnsweringLoss,
42
+ TFSequenceClassificationLoss,
43
+ TFTokenClassificationLoss,
44
+ get_initializer,
45
+ keras,
46
+ keras_serializable,
47
+ unpack_inputs,
48
+ )
49
+ from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
50
+ from ...utils import (
51
+ ModelOutput,
52
+ add_code_sample_docstrings,
53
+ add_start_docstrings,
54
+ add_start_docstrings_to_model_forward,
55
+ logging,
56
+ replace_return_docstrings,
57
+ )
58
+ from .configuration_albert import AlbertConfig
59
+
60
+
61
+ logger = logging.get_logger(__name__)
62
+
63
+ _CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
64
+ _CONFIG_FOR_DOC = "AlbertConfig"
65
+
66
+
67
+ class TFAlbertPreTrainingLoss:
68
+ """
69
+ Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP +
70
+ MLM.
+
+ .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
71
+ """
72
+
73
+ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
74
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
75
+ if self.config.tf_legacy_loss:
76
+ # make sure only labels that are not equal to -100
77
+ # are taken into account for the loss computation
78
+ masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
79
+ masked_lm_reduced_logits = tf.boolean_mask(
80
+ tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
81
+ mask=masked_lm_active_loss,
82
+ )
83
+ masked_lm_labels = tf.boolean_mask(
84
+ tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
85
+ )
86
+ sentence_order_active_loss = tf.not_equal(
87
+ tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
88
+ )
89
+ sentence_order_reduced_logits = tf.boolean_mask(
90
+ tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
91
+ )
92
+ sentence_order_label = tf.boolean_mask(
93
+ tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
94
+ )
95
+ masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
96
+ sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
97
+ masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
98
+ masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)
99
+
100
+ return masked_lm_loss + sentence_order_loss
101
+
102
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
103
+ unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
104
+ # make sure only labels that are not equal to -100
105
+ # are taken into account for the loss computation
106
+ lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
107
+ masked_lm_losses = unmasked_lm_losses * lm_loss_mask
108
+ reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)
109
+
110
+ sop_logits = tf.reshape(logits[1], (-1, 2))
111
+ # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
112
+ unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
113
+ sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)
114
+
115
+ masked_sop_loss = unmasked_sop_loss * sop_loss_mask
116
+ reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)
117
+
118
+ return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))
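The non-legacy branch above applies the same `-100`-masking recipe to both the MLM and SOP terms; a standalone sketch of that recipe (toy shapes, not values from the diff):

```python
import tensorflow as tf

labels = tf.constant([5, -100, 3])          # -100 marks positions to ignore
logits = tf.random.normal((3, 10))          # 3 tokens, toy vocabulary of 10
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

# Clip -100 to 0 so the loss op never sees an invalid class index...
per_token_loss = loss_fn(tf.nn.relu(labels), logits)
# ...then zero out the ignored positions and average over the rest.
mask = tf.cast(labels != -100, per_token_loss.dtype)
mean_loss = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)
```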
119
+
120
+
121
+ class TFAlbertEmbeddings(keras.layers.Layer):
122
+ """Construct the embeddings from word, position and token_type embeddings."""
123
+
124
+ def __init__(self, config: AlbertConfig, **kwargs):
125
+ super().__init__(**kwargs)
126
+
127
+ self.config = config
128
+ self.embedding_size = config.embedding_size
129
+ self.max_position_embeddings = config.max_position_embeddings
130
+ self.initializer_range = config.initializer_range
131
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
132
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
133
+
134
+ def build(self, input_shape=None):
135
+ with tf.name_scope("word_embeddings"):
136
+ self.weight = self.add_weight(
137
+ name="weight",
138
+ shape=[self.config.vocab_size, self.embedding_size],
139
+ initializer=get_initializer(self.initializer_range),
140
+ )
141
+
142
+ with tf.name_scope("token_type_embeddings"):
143
+ self.token_type_embeddings = self.add_weight(
144
+ name="embeddings",
145
+ shape=[self.config.type_vocab_size, self.embedding_size],
146
+ initializer=get_initializer(self.initializer_range),
147
+ )
148
+
149
+ with tf.name_scope("position_embeddings"):
150
+ self.position_embeddings = self.add_weight(
151
+ name="embeddings",
152
+ shape=[self.max_position_embeddings, self.embedding_size],
153
+ initializer=get_initializer(self.initializer_range),
154
+ )
155
+
156
+ if self.built:
157
+ return
158
+ self.built = True
159
+ if getattr(self, "LayerNorm", None) is not None:
160
+ with tf.name_scope(self.LayerNorm.name):
161
+ self.LayerNorm.build([None, None, self.config.embedding_size])
162
+
163
+ # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
164
+ def call(
165
+ self,
166
+ input_ids: tf.Tensor | None = None,
167
+ position_ids: tf.Tensor | None = None,
168
+ token_type_ids: tf.Tensor | None = None,
169
+ inputs_embeds: tf.Tensor | None = None,
170
+ past_key_values_length=0,
171
+ training: bool = False,
172
+ ) -> tf.Tensor:
173
+ """
174
+ Applies embedding based on inputs tensor.
175
+
176
+ Returns:
177
+ final_embeddings (`tf.Tensor`): output embedding tensor.
178
+ """
179
+ if input_ids is None and inputs_embeds is None:
180
+ raise ValueError("Need to provide either `input_ids` or `inputs_embeds`.")
181
+
182
+ if input_ids is not None:
183
+ check_embeddings_within_bounds(input_ids, self.config.vocab_size)
184
+ inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
185
+
186
+ input_shape = shape_list(inputs_embeds)[:-1]
187
+
188
+ if token_type_ids is None:
189
+ token_type_ids = tf.fill(dims=input_shape, value=0)
190
+
191
+ if position_ids is None:
192
+ position_ids = tf.expand_dims(
193
+ tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
194
+ )
195
+
196
+ position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
197
+ token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
198
+ final_embeddings = inputs_embeds + position_embeds + token_type_embeds
199
+ final_embeddings = self.LayerNorm(inputs=final_embeddings)
200
+ final_embeddings = self.dropout(inputs=final_embeddings, training=training)
201
+
202
+ return final_embeddings
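`TFAlbertEmbeddings.call` above reduces to three `tf.gather` lookups that are summed before LayerNorm and dropout; a shape-only sketch (ALBERT's default `embedding_size` of 128 is assumed):

```python
import tensorflow as tf

vocab_size, type_vocab_size, max_positions, embedding_size = 30000, 2, 512, 128
word_emb = tf.random.normal((vocab_size, embedding_size))
type_emb = tf.random.normal((type_vocab_size, embedding_size))
pos_emb = tf.random.normal((max_positions, embedding_size))

input_ids = tf.constant([[7, 42, 9]])
token_type_ids = tf.zeros_like(input_ids)
position_ids = tf.range(3)[None, :]

# Sum of the three lookups -> (1, 3, embedding_size); LayerNorm and dropout
# are then applied on top, as in `call` above.
embeddings = (
    tf.gather(word_emb, input_ids)
    + tf.gather(type_emb, token_type_ids)
    + tf.gather(pos_emb, position_ids)
)
```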
203
+
204
+
205
+ class TFAlbertAttention(keras.layers.Layer):
206
+ """Contains the complete attention sublayer, including both dropouts and layer norm."""
207
+
208
+ def __init__(self, config: AlbertConfig, **kwargs):
209
+ super().__init__(**kwargs)
210
+
211
+ if config.hidden_size % config.num_attention_heads != 0:
212
+ raise ValueError(
213
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
214
+ f"of attention heads ({config.num_attention_heads})"
215
+ )
216
+
217
+ self.num_attention_heads = config.num_attention_heads
218
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
219
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
220
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
221
+ self.output_attentions = config.output_attentions
222
+
223
+ self.query = keras.layers.Dense(
224
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
225
+ )
226
+ self.key = keras.layers.Dense(
227
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
228
+ )
229
+ self.value = keras.layers.Dense(
230
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
231
+ )
232
+ self.dense = keras.layers.Dense(
233
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
234
+ )
235
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
236
+ # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
237
+ self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
238
+ self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
239
+ self.config = config
240
+
241
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
242
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
243
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
244
+
245
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
246
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
247
+
248
+ def call(
249
+ self,
250
+ input_tensor: tf.Tensor,
251
+ attention_mask: tf.Tensor,
252
+ head_mask: tf.Tensor,
253
+ output_attentions: bool,
254
+ training: bool = False,
255
+ ) -> tuple[tf.Tensor]:
256
+ batch_size = shape_list(input_tensor)[0]
257
+ mixed_query_layer = self.query(inputs=input_tensor)
258
+ mixed_key_layer = self.key(inputs=input_tensor)
259
+ mixed_value_layer = self.value(inputs=input_tensor)
260
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
261
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
262
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
263
+
264
+ # Take the dot product between "query" and "key" to get the raw attention scores.
265
+ # (batch size, num_heads, seq_len_q, seq_len_k)
266
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
267
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
268
+ attention_scores = tf.divide(attention_scores, dk)
269
+
270
+ if attention_mask is not None:
271
+ # Apply the attention mask (precomputed for all layers in the TFAlbertModel call() function)
272
+ attention_scores = tf.add(attention_scores, attention_mask)
273
+
274
+ # Normalize the attention scores to probabilities.
275
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
276
+
277
+ # This is actually dropping out entire tokens to attend to, which might
278
+ # seem a bit unusual, but is taken from the original Transformer paper.
279
+ attention_probs = self.attention_dropout(inputs=attention_probs, training=training)
280
+
281
+ # Mask heads if we want to
282
+ if head_mask is not None:
283
+ attention_probs = tf.multiply(attention_probs, head_mask)
284
+
285
+ context_layer = tf.matmul(attention_probs, value_layer)
286
+ context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
287
+
288
+ # (batch_size, seq_len_q, all_head_size)
289
+ context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size))
290
+ self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
291
+ hidden_states = self_outputs[0]
292
+ hidden_states = self.dense(inputs=hidden_states)
293
+ hidden_states = self.output_dropout(inputs=hidden_states, training=training)
294
+ attention_output = self.LayerNorm(inputs=hidden_states + input_tensor)
295
+
296
+ # add attentions if we output them
297
+ outputs = (attention_output,) + self_outputs[1:]
298
+
299
+ return outputs
300
+
301
+ def build(self, input_shape=None):
302
+ if self.built:
303
+ return
304
+ self.built = True
305
+ if getattr(self, "query", None) is not None:
306
+ with tf.name_scope(self.query.name):
307
+ self.query.build([None, None, self.config.hidden_size])
308
+ if getattr(self, "key", None) is not None:
309
+ with tf.name_scope(self.key.name):
310
+ self.key.build([None, None, self.config.hidden_size])
311
+ if getattr(self, "value", None) is not None:
312
+ with tf.name_scope(self.value.name):
313
+ self.value.build([None, None, self.config.hidden_size])
314
+ if getattr(self, "dense", None) is not None:
315
+ with tf.name_scope(self.dense.name):
316
+ self.dense.build([None, None, self.config.hidden_size])
317
+ if getattr(self, "LayerNorm", None) is not None:
318
+ with tf.name_scope(self.LayerNorm.name):
319
+ self.LayerNorm.build([None, None, self.config.hidden_size])
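`transpose_for_scores` and the closing reshape in `TFAlbertAttention.call` are pure shape bookkeeping; this sketch walks the same shapes end to end (head counts are illustrative):

```python
import tensorflow as tf

batch, seq, num_heads, head_size = 2, 5, 12, 64
all_head_size = num_heads * head_size
x = tf.random.normal((batch, seq, all_head_size))

# Split: (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
heads = tf.transpose(tf.reshape(x, (batch, -1, num_heads, head_size)), perm=[0, 2, 1, 3])

# Scaled dot-product attention per head: (batch, num_heads, seq, seq)
scores = tf.matmul(heads, heads, transpose_b=True) / tf.sqrt(float(head_size))
context = tf.matmul(tf.nn.softmax(scores, axis=-1), heads)

# Merge back: (batch, num_heads, seq, head_size) -> (batch, seq, all_head_size)
merged = tf.reshape(tf.transpose(context, perm=[0, 2, 1, 3]), (batch, -1, all_head_size))
assert merged.shape == (batch, seq, all_head_size)
```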
320
+
321
+
322
+ class TFAlbertLayer(keras.layers.Layer):
323
+ def __init__(self, config: AlbertConfig, **kwargs):
324
+ super().__init__(**kwargs)
325
+
326
+ self.attention = TFAlbertAttention(config, name="attention")
327
+ self.ffn = keras.layers.Dense(
328
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
329
+ )
330
+
331
+ if isinstance(config.hidden_act, str):
332
+ self.activation = get_tf_activation(config.hidden_act)
333
+ else:
334
+ self.activation = config.hidden_act
335
+
336
+ self.ffn_output = keras.layers.Dense(
337
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output"
338
+ )
339
+ self.full_layer_layer_norm = keras.layers.LayerNormalization(
340
+ epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
341
+ )
342
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
343
+ self.config = config
344
+
345
+ def call(
346
+ self,
347
+ hidden_states: tf.Tensor,
348
+ attention_mask: tf.Tensor,
349
+ head_mask: tf.Tensor,
350
+ output_attentions: bool,
351
+ training: bool = False,
352
+ ) -> tuple[tf.Tensor]:
353
+ attention_outputs = self.attention(
354
+ input_tensor=hidden_states,
355
+ attention_mask=attention_mask,
356
+ head_mask=head_mask,
357
+ output_attentions=output_attentions,
358
+ training=training,
359
+ )
360
+ ffn_output = self.ffn(inputs=attention_outputs[0])
361
+ ffn_output = self.activation(ffn_output)
362
+ ffn_output = self.ffn_output(inputs=ffn_output)
363
+ ffn_output = self.dropout(inputs=ffn_output, training=training)
364
+ hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0])
365
+
366
+ # add attentions if we output them
367
+ outputs = (hidden_states,) + attention_outputs[1:]
368
+
369
+ return outputs
370
+
371
+ def build(self, input_shape=None):
372
+ if self.built:
373
+ return
374
+ self.built = True
375
+ if getattr(self, "attention", None) is not None:
376
+ with tf.name_scope(self.attention.name):
377
+ self.attention.build(None)
378
+ if getattr(self, "ffn", None) is not None:
379
+ with tf.name_scope(self.ffn.name):
380
+ self.ffn.build([None, None, self.config.hidden_size])
381
+ if getattr(self, "ffn_output", None) is not None:
382
+ with tf.name_scope(self.ffn_output.name):
383
+ self.ffn_output.build([None, None, self.config.intermediate_size])
384
+ if getattr(self, "full_layer_layer_norm", None) is not None:
385
+ with tf.name_scope(self.full_layer_layer_norm.name):
386
+ self.full_layer_layer_norm.build([None, None, self.config.hidden_size])
387
+
388
+
389
+ class TFAlbertLayerGroup(keras.layers.Layer):
390
+ def __init__(self, config: AlbertConfig, **kwargs):
391
+ super().__init__(**kwargs)
392
+
393
+ self.albert_layers = [
394
+ TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num)
395
+ ]
396
+
397
+ def call(
398
+ self,
399
+ hidden_states: tf.Tensor,
400
+ attention_mask: tf.Tensor,
401
+ head_mask: tf.Tensor,
402
+ output_attentions: bool,
403
+ output_hidden_states: bool,
404
+ training: bool = False,
405
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
406
+ layer_hidden_states = () if output_hidden_states else None
407
+ layer_attentions = () if output_attentions else None
408
+
409
+ for layer_index, albert_layer in enumerate(self.albert_layers):
410
+ if output_hidden_states:
411
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
412
+
413
+ layer_output = albert_layer(
414
+ hidden_states=hidden_states,
415
+ attention_mask=attention_mask,
416
+ head_mask=head_mask[layer_index],
417
+ output_attentions=output_attentions,
418
+ training=training,
419
+ )
420
+ hidden_states = layer_output[0]
421
+
422
+ if output_attentions:
423
+ layer_attentions = layer_attentions + (layer_output[1],)
424
+
425
+ # Add last layer
426
+ if output_hidden_states:
427
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
428
+
429
+ return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)
430
+
431
+ def build(self, input_shape=None):
432
+ if self.built:
433
+ return
434
+ self.built = True
435
+ if getattr(self, "albert_layers", None) is not None:
436
+ for layer in self.albert_layers:
437
+ with tf.name_scope(layer.name):
438
+ layer.build(None)
439
+
440
+
441
+ class TFAlbertTransformer(keras.layers.Layer):
442
+ def __init__(self, config: AlbertConfig, **kwargs):
443
+ super().__init__(**kwargs)
444
+
445
+ self.num_hidden_layers = config.num_hidden_layers
446
+ self.num_hidden_groups = config.num_hidden_groups
447
+ # Number of layers in a hidden group
448
+ self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups)
449
+ self.embedding_hidden_mapping_in = keras.layers.Dense(
450
+ units=config.hidden_size,
451
+ kernel_initializer=get_initializer(config.initializer_range),
452
+ name="embedding_hidden_mapping_in",
453
+ )
454
+ self.albert_layer_groups = [
455
+ TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
456
+ ]
457
+ self.config = config
458
+
459
+ def call(
460
+ self,
461
+ hidden_states: tf.Tensor,
462
+ attention_mask: tf.Tensor,
463
+ head_mask: tf.Tensor,
464
+ output_attentions: bool,
465
+ output_hidden_states: bool,
466
+ return_dict: bool,
467
+ training: bool = False,
468
+ ) -> TFBaseModelOutput | tuple[tf.Tensor]:
469
+ hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
470
+ all_attentions = () if output_attentions else None
471
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
472
+
473
+ for i in range(self.num_hidden_layers):
474
+ # Index of the hidden group
475
+ group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups))
476
+ layer_group_output = self.albert_layer_groups[group_idx](
477
+ hidden_states=hidden_states,
478
+ attention_mask=attention_mask,
479
+ head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group],
480
+ output_attentions=output_attentions,
481
+ output_hidden_states=output_hidden_states,
482
+ training=training,
483
+ )
484
+ hidden_states = layer_group_output[0]
485
+
486
+ if output_attentions:
487
+ all_attentions = all_attentions + layer_group_output[-1]
488
+
489
+ if output_hidden_states:
490
+ all_hidden_states = all_hidden_states + (hidden_states,)
491
+
492
+ if not return_dict:
493
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
494
+
495
+ return TFBaseModelOutput(
496
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
497
+ )
498
+
499
+ def build(self, input_shape=None):
500
+ if self.built:
501
+ return
502
+ self.built = True
503
+ if getattr(self, "embedding_hidden_mapping_in", None) is not None:
504
+ with tf.name_scope(self.embedding_hidden_mapping_in.name):
505
+ self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
506
+ if getattr(self, "albert_layer_groups", None) is not None:
507
+ for layer in self.albert_layer_groups:
508
+ with tf.name_scope(layer.name):
509
+ layer.build(None)
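The `group_idx` arithmetic in `TFAlbertTransformer.call` is what implements ALBERT's cross-layer parameter sharing; a tiny sketch with an assumed non-default group count to make the mapping visible:

```python
num_hidden_layers, num_hidden_groups = 12, 3   # the default config uses 1 group
layers_per_group = num_hidden_layers // num_hidden_groups  # 4

for i in range(num_hidden_layers):
    group_idx = int(i / (num_hidden_layers / num_hidden_groups))
    # layers 0-3 -> group 0, layers 4-7 -> group 1, layers 8-11 -> group 2
    assert group_idx == i // layers_per_group
```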
510
+
511
+
512
+ class TFAlbertPreTrainedModel(TFPreTrainedModel):
513
+ """
514
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
515
+ models.
516
+ """
517
+
518
+ config_class = AlbertConfig
519
+ base_model_prefix = "albert"
520
+
521
+
522
+ class TFAlbertMLMHead(keras.layers.Layer):
523
+ def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs):
524
+ super().__init__(**kwargs)
525
+
526
+ self.config = config
527
+ self.embedding_size = config.embedding_size
528
+ self.dense = keras.layers.Dense(
529
+ config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
530
+ )
531
+ if isinstance(config.hidden_act, str):
532
+ self.activation = get_tf_activation(config.hidden_act)
533
+ else:
534
+ self.activation = config.hidden_act
535
+
536
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
537
+
538
+ # The output weights are the same as the input embeddings, but there is
539
+ # an output-only bias for each token.
540
+ self.decoder = input_embeddings
541
+
542
+ def build(self, input_shape=None):
543
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
544
+ self.decoder_bias = self.add_weight(
545
+ shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
546
+ )
547
+
548
+ if self.built:
549
+ return
550
+ self.built = True
551
+ if getattr(self, "dense", None) is not None:
552
+ with tf.name_scope(self.dense.name):
553
+ self.dense.build([None, None, self.config.hidden_size])
554
+ if getattr(self, "LayerNorm", None) is not None:
555
+ with tf.name_scope(self.LayerNorm.name):
556
+ self.LayerNorm.build([None, None, self.config.embedding_size])
557
+
558
+ def get_output_embeddings(self) -> keras.layers.Layer:
559
+ return self.decoder
560
+
561
+ def set_output_embeddings(self, value: tf.Variable):
562
+ self.decoder.weight = value
563
+ self.decoder.vocab_size = shape_list(value)[0]
564
+
565
+ def get_bias(self) -> dict[str, tf.Variable]:
566
+ return {"bias": self.bias, "decoder_bias": self.decoder_bias}
567
+
568
+ def set_bias(self, value: tf.Variable):
569
+ self.bias = value["bias"]
570
+ self.decoder_bias = value["decoder_bias"]
571
+ self.config.vocab_size = shape_list(value["bias"])[0]
572
+
573
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
574
+ hidden_states = self.dense(inputs=hidden_states)
575
+ hidden_states = self.activation(hidden_states)
576
+ hidden_states = self.LayerNorm(inputs=hidden_states)
577
+ seq_length = shape_list(tensor=hidden_states)[1]
578
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
579
+ hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
580
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
581
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias)
582
+
583
+ return hidden_states
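The matmul in `TFAlbertMLMHead.call` reuses the input embedding matrix (transposed) as the decoder, which is why `hidden_states` is first projected down to `embedding_size`; a shape-only sketch under assumed ALBERT-base dimensions:

```python
import tensorflow as tf

batch, seq, embedding_size, vocab_size = 2, 5, 128, 30000
hidden_states = tf.random.normal((batch, seq, embedding_size))     # after dense + LayerNorm
embedding_matrix = tf.random.normal((vocab_size, embedding_size))  # shared with the inputs

flat = tf.reshape(hidden_states, (-1, embedding_size))
scores = tf.matmul(flat, embedding_matrix, transpose_b=True)
scores = tf.reshape(scores, (-1, seq, vocab_size))  # (batch, seq, vocab_size)
```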
584
+
585
+
586
+ @keras_serializable
587
+ class TFAlbertMainLayer(keras.layers.Layer):
588
+ config_class = AlbertConfig
589
+
590
+ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs):
591
+ super().__init__(**kwargs)
592
+
593
+ self.config = config
594
+
595
+ self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
596
+ self.encoder = TFAlbertTransformer(config, name="encoder")
597
+ self.pooler = (
598
+ keras.layers.Dense(
599
+ units=config.hidden_size,
600
+ kernel_initializer=get_initializer(config.initializer_range),
601
+ activation="tanh",
602
+ name="pooler",
603
+ )
604
+ if add_pooling_layer
605
+ else None
606
+ )
607
+
608
+ def get_input_embeddings(self) -> keras.layers.Layer:
609
+ return self.embeddings
610
+
611
+ def set_input_embeddings(self, value: tf.Variable):
612
+ self.embeddings.weight = value
613
+ self.embeddings.vocab_size = shape_list(value)[0]
614
+
615
+ def _prune_heads(self, heads_to_prune):
616
+ """
617
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
618
+ class PreTrainedModel
619
+ """
620
+ raise NotImplementedError
621
+
622
+ @unpack_inputs
623
+ def call(
624
+ self,
625
+ input_ids: TFModelInputType | None = None,
626
+ attention_mask: np.ndarray | tf.Tensor | None = None,
627
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
628
+ position_ids: np.ndarray | tf.Tensor | None = None,
629
+ head_mask: np.ndarray | tf.Tensor | None = None,
630
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
631
+ output_attentions: bool | None = None,
632
+ output_hidden_states: bool | None = None,
633
+ return_dict: bool | None = None,
634
+ training: bool = False,
635
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
636
+ if input_ids is not None and inputs_embeds is not None:
637
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
638
+ elif input_ids is not None:
639
+ input_shape = shape_list(input_ids)
640
+ elif inputs_embeds is not None:
641
+ input_shape = shape_list(inputs_embeds)[:-1]
642
+ else:
643
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
644
+
645
+ if attention_mask is None:
646
+ attention_mask = tf.fill(dims=input_shape, value=1)
647
+
648
+ if token_type_ids is None:
649
+ token_type_ids = tf.fill(dims=input_shape, value=0)
650
+
651
+ embedding_output = self.embeddings(
652
+ input_ids=input_ids,
653
+ position_ids=position_ids,
654
+ token_type_ids=token_type_ids,
655
+ inputs_embeds=inputs_embeds,
656
+ training=training,
657
+ )
658
+
659
+ # We create a 3D attention mask from a 2D tensor mask.
660
+ # Sizes are [batch_size, 1, 1, to_seq_length]
661
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
662
+ # this attention mask is simpler than the triangular masking of causal attention
663
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
664
+ extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
665
+
666
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
667
+ # masked positions, this operation will create a tensor which is 0.0 for
668
+ # positions we want to attend and -10000.0 for masked positions.
669
+ # Since we are adding it to the raw scores before the softmax, this is
670
+ # effectively the same as removing these entirely.
671
+ extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
672
+ one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
673
+ ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
674
+ extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
675
+
676
+ # Prepare head mask if needed
677
+ # 1.0 in head_mask indicate we keep the head
678
+ # attention_probs has shape bsz x n_heads x N x N
679
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
680
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
681
+ if head_mask is not None:
682
+ raise NotImplementedError
683
+ else:
684
+ head_mask = [None] * self.config.num_hidden_layers
685
+
686
+ encoder_outputs = self.encoder(
687
+ hidden_states=embedding_output,
688
+ attention_mask=extended_attention_mask,
689
+ head_mask=head_mask,
690
+ output_attentions=output_attentions,
691
+ output_hidden_states=output_hidden_states,
692
+ return_dict=return_dict,
693
+ training=training,
694
+ )
695
+
696
+ sequence_output = encoder_outputs[0]
697
+ pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None
698
+
699
+ if not return_dict:
700
+ return (
701
+ sequence_output,
702
+ pooled_output,
703
+ ) + encoder_outputs[1:]
704
+
705
+ return TFBaseModelOutputWithPooling(
706
+ last_hidden_state=sequence_output,
707
+ pooler_output=pooled_output,
708
+ hidden_states=encoder_outputs.hidden_states,
709
+ attentions=encoder_outputs.attentions,
710
+ )
711
+
712
+ def build(self, input_shape=None):
713
+ if self.built:
714
+ return
715
+ self.built = True
716
+ if getattr(self, "embeddings", None) is not None:
717
+ with tf.name_scope(self.embeddings.name):
718
+ self.embeddings.build(None)
719
+ if getattr(self, "encoder", None) is not None:
720
+ with tf.name_scope(self.encoder.name):
721
+ self.encoder.build(None)
722
+ if getattr(self, "pooler", None) is not None:
723
+ with tf.name_scope(self.pooler.name):
724
+ self.pooler.build([None, None, self.config.hidden_size])
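The additive mask construction in `TFAlbertMainLayer.call` is compact but worth seeing with concrete numbers; a sketch with a single padded sequence (values illustrative):

```python
import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 0]], dtype=tf.float32)  # last token is padding
extended = tf.reshape(attention_mask, (1, 1, 1, 4))             # broadcastable over heads
extended = (1.0 - extended) * -10000.0
print(extended)  # [[[[-0. -0. -0. -10000.]]]] -- added to the raw attention scores
```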
725
+
726
+
727
+ @dataclass
728
+ class TFAlbertForPreTrainingOutput(ModelOutput):
729
+ """
730
+ Output type of [`TFAlbertForPreTraining`].
731
+
732
+ Args:
733
+ prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
734
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
735
+ sop_logits (`tf.Tensor` of shape `(batch_size, 2)`):
736
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
737
+ before SoftMax).
738
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
739
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
740
+ `(batch_size, sequence_length, hidden_size)`.
741
+
742
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
743
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
744
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
745
+ sequence_length)`.
746
+
747
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
748
+ heads.
749
+ """
750
+
751
+ loss: tf.Tensor | None = None
752
+ prediction_logits: tf.Tensor | None = None
753
+ sop_logits: tf.Tensor | None = None
754
+ hidden_states: tuple[tf.Tensor] | None = None
755
+ attentions: tuple[tf.Tensor] | None = None
756
+
757
+
758
+ ALBERT_START_DOCSTRING = r"""
759
+
760
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
761
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
762
+ etc.)
763
+
764
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
765
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
766
+ behavior.
767
+
768
+ <Tip>
769
+
770
+ TensorFlow models and layers in `transformers` accept two formats as input:
771
+
772
+ - having all inputs as keyword arguments (like PyTorch models), or
773
+ - having all inputs as a list, tuple or dict in the first positional argument.
774
+
775
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
776
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
777
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
778
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
779
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
780
+ positional argument:
781
+
782
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
783
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
784
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
785
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
786
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
787
+
788
+ Note that when creating models and layers with
789
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
790
+ about any of this, as you can just pass inputs like you would to any other Python function!
791
+
792
+ </Tip>
793
+
794
+ Args:
795
+ config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
796
+ Initializing with a config file does not load the weights associated with the model, only the
797
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
798
+ """
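A sketch of the three Keras-friendly input formats described in the tip above (model and checkpoint choice are assumptions for illustration):

```python
from transformers import AutoTokenizer, TFAlbertModel

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = TFAlbertModel.from_pretrained("albert/albert-base-v2")
enc = tokenizer("Hello", return_tensors="tf")

out_tensor = model(enc["input_ids"])                                 # single tensor
out_list = model([enc["input_ids"], enc["attention_mask"]])          # positional list
out_dict = model({"input_ids": enc["input_ids"],
                  "attention_mask": enc["attention_mask"]})          # named dict
```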
799
+
800
+ ALBERT_INPUTS_DOCSTRING = r"""
801
+ Args:
802
+ input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
803
+ Indices of input sequence tokens in the vocabulary.
804
+
805
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
806
+ [`PreTrainedTokenizer.encode`] for details.
807
+
808
+ [What are input IDs?](../glossary#input-ids)
809
+ attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
810
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
811
+
812
+ - 1 for tokens that are **not masked**,
813
+ - 0 for tokens that are **masked**.
814
+
815
+ [What are attention masks?](../glossary#attention-mask)
816
+ token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
817
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
818
+ 1]`:
819
+
820
+ - 0 corresponds to a *sentence A* token,
821
+ - 1 corresponds to a *sentence B* token.
822
+
823
+ [What are token type IDs?](../glossary#token-type-ids)
824
+ position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
825
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
826
+ config.max_position_embeddings - 1]`.
827
+
828
+ [What are position IDs?](../glossary#position-ids)
829
+ head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
830
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
831
+
832
+ - 1 indicates the head is **not masked**,
833
+ - 0 indicates the head is **masked**.
834
+
835
+ inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
836
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
837
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
838
+ model's internal embedding lookup matrix.
839
+ output_attentions (`bool`, *optional*):
840
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
841
+ tensors for more detail. This argument can be used only in eager mode; in graph mode, the value in the
842
+ config will be used instead.
843
+ output_hidden_states (`bool`, *optional*):
844
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
845
+ more detail. This argument can be used only in eager mode; in graph mode, the value in the config will be
846
+ used instead.
847
+ return_dict (`bool`, *optional*):
848
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
849
+ eager mode; in graph mode, the value will always be set to True.
850
+ training (`bool`, *optional*, defaults to `False`):
851
+ Whether or not to use the model in training mode (some modules like dropout modules have different
852
+ behaviors between training and evaluation).
853
+ """
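The tokenizer produces `attention_mask` and `token_type_ids` exactly as the docstring above defines them; a short sketch with a padded sentence pair (lengths are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
enc = tokenizer("first segment", "second segment",
                padding="max_length", max_length=16, return_tensors="np")
print(enc["attention_mask"])  # 1 = real token, 0 = padding
print(enc["token_type_ids"])  # 0 = sentence A tokens, 1 = sentence B tokens
```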
854
+
855
+
856
+ @add_start_docstrings(
857
+ "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
858
+ ALBERT_START_DOCSTRING,
859
+ )
860
+ class TFAlbertModel(TFAlbertPreTrainedModel):
861
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
862
+ super().__init__(config, *inputs, **kwargs)
863
+
864
+ self.albert = TFAlbertMainLayer(config, name="albert")
865
+
866
+ @unpack_inputs
867
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
868
+ @add_code_sample_docstrings(
869
+ checkpoint=_CHECKPOINT_FOR_DOC,
870
+ output_type=TFBaseModelOutputWithPooling,
871
+ config_class=_CONFIG_FOR_DOC,
872
+ )
873
+ def call(
874
+ self,
875
+ input_ids: TFModelInputType | None = None,
876
+ attention_mask: np.ndarray | tf.Tensor | None = None,
877
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
878
+ position_ids: np.ndarray | tf.Tensor | None = None,
879
+ head_mask: np.ndarray | tf.Tensor | None = None,
880
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
881
+ output_attentions: bool | None = None,
882
+ output_hidden_states: bool | None = None,
883
+ return_dict: bool | None = None,
884
+ training: bool | None = False,
885
+ ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
886
+ outputs = self.albert(
887
+ input_ids=input_ids,
888
+ attention_mask=attention_mask,
889
+ token_type_ids=token_type_ids,
890
+ position_ids=position_ids,
891
+ head_mask=head_mask,
892
+ inputs_embeds=inputs_embeds,
893
+ output_attentions=output_attentions,
894
+ output_hidden_states=output_hidden_states,
895
+ return_dict=return_dict,
896
+ training=training,
897
+ )
898
+
899
+ return outputs
900
+
901
+ def build(self, input_shape=None):
902
+ if self.built:
903
+ return
904
+ self.built = True
905
+ if getattr(self, "albert", None) is not None:
906
+ with tf.name_scope(self.albert.name):
907
+ self.albert.build(None)
908
+
909
+
910
+ @add_start_docstrings(
911
+ """
912
+ Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
913
+ prediction` (classification) head.
914
+ """,
915
+ ALBERT_START_DOCSTRING,
916
+ )
917
+ class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
918
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
919
+ _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]
920
+
921
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
922
+ super().__init__(config, *inputs, **kwargs)
923
+
924
+ self.num_labels = config.num_labels
925
+
926
+ self.albert = TFAlbertMainLayer(config, name="albert")
927
+ self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
928
+ self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")
929
+
930
+ def get_lm_head(self) -> keras.layers.Layer:
931
+ return self.predictions
932
+
933
+ @unpack_inputs
934
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
935
+ @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
936
+ def call(
937
+ self,
938
+ input_ids: TFModelInputType | None = None,
939
+ attention_mask: np.ndarray | tf.Tensor | None = None,
940
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
941
+ position_ids: np.ndarray | tf.Tensor | None = None,
942
+ head_mask: np.ndarray | tf.Tensor | None = None,
943
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
944
+ output_attentions: bool | None = None,
945
+ output_hidden_states: bool | None = None,
946
+ return_dict: bool | None = None,
947
+ labels: np.ndarray | tf.Tensor | None = None,
948
+ sentence_order_label: np.ndarray | tf.Tensor | None = None,
949
+ training: bool | None = False,
950
+ ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]:
951
+ r"""
952
+ Returns:
953
+
954
+ Example:
955
+
956
+ ```python
957
+ >>> import tensorflow as tf
958
+ >>> from transformers import AutoTokenizer, TFAlbertForPreTraining
959
+
960
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
961
+ >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2")
962
+
963
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
964
+ >>> # Batch size 1
965
+ >>> outputs = model(input_ids)
966
+
967
+ >>> prediction_logits = outputs.prediction_logits
968
+ >>> sop_logits = outputs.sop_logits
969
+ ```"""
970
+
971
+ outputs = self.albert(
972
+ input_ids=input_ids,
973
+ attention_mask=attention_mask,
974
+ token_type_ids=token_type_ids,
975
+ position_ids=position_ids,
976
+ head_mask=head_mask,
977
+ inputs_embeds=inputs_embeds,
978
+ output_attentions=output_attentions,
979
+ output_hidden_states=output_hidden_states,
980
+ return_dict=return_dict,
981
+ training=training,
982
+ )
983
+ sequence_output, pooled_output = outputs[:2]
984
+ prediction_scores = self.predictions(hidden_states=sequence_output)
985
+ sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training)
986
+ total_loss = None
987
+
988
+ if labels is not None and sentence_order_label is not None:
989
+ d_labels = {"labels": labels}
990
+ d_labels["sentence_order_label"] = sentence_order_label
991
+ total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))
992
+
993
+ if not return_dict:
994
+ output = (prediction_scores, sop_scores) + outputs[2:]
995
+ return ((total_loss,) + output) if total_loss is not None else output
996
+
997
+ return TFAlbertForPreTrainingOutput(
998
+ loss=total_loss,
999
+ prediction_logits=prediction_scores,
1000
+ sop_logits=sop_scores,
1001
+ hidden_states=outputs.hidden_states,
1002
+ attentions=outputs.attentions,
1003
+ )
1004
+
1005
+ def build(self, input_shape=None):
1006
+ if self.built:
1007
+ return
1008
+ self.built = True
1009
+ if getattr(self, "albert", None) is not None:
1010
+ with tf.name_scope(self.albert.name):
1011
+ self.albert.build(None)
1012
+ if getattr(self, "predictions", None) is not None:
1013
+ with tf.name_scope(self.predictions.name):
1014
+ self.predictions.build(None)
1015
+ if getattr(self, "sop_classifier", None) is not None:
1016
+ with tf.name_scope(self.sop_classifier.name):
1017
+ self.sop_classifier.build(None)
1018
+
1019
+
1020
+ class TFAlbertSOPHead(keras.layers.Layer):
1021
+ def __init__(self, config: AlbertConfig, **kwargs):
1022
+ super().__init__(**kwargs)
1023
+
1024
+ self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
1025
+ self.classifier = keras.layers.Dense(
1026
+ units=config.num_labels,
1027
+ kernel_initializer=get_initializer(config.initializer_range),
1028
+ name="classifier",
1029
+ )
1030
+ self.config = config
1031
+
1032
+ def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
1033
+ dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
1034
+ logits = self.classifier(inputs=dropout_pooled_output)
1035
+
1036
+ return logits
1037
+
1038
+ def build(self, input_shape=None):
1039
+ if self.built:
1040
+ return
1041
+ self.built = True
1042
+ if getattr(self, "classifier", None) is not None:
1043
+ with tf.name_scope(self.classifier.name):
1044
+ self.classifier.build([None, None, self.config.hidden_size])
1045
+
1046
+
1047
+ @add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
1048
+ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
1049
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1050
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
1051
+
1052
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1053
+ super().__init__(config, *inputs, **kwargs)
1054
+
1055
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
1056
+ self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
1057
+
1058
+ def get_lm_head(self) -> keras.layers.Layer:
1059
+ return self.predictions
1060
+
1061
+ @unpack_inputs
1062
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1063
+ @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1064
+ def call(
1065
+ self,
1066
+ input_ids: TFModelInputType | None = None,
1067
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1068
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1069
+ position_ids: np.ndarray | tf.Tensor | None = None,
1070
+ head_mask: np.ndarray | tf.Tensor | None = None,
1071
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1072
+ output_attentions: bool | None = None,
1073
+ output_hidden_states: bool | None = None,
1074
+ return_dict: bool | None = None,
1075
+ labels: np.ndarray | tf.Tensor | None = None,
1076
+ training: bool | None = False,
1077
+ ) -> TFMaskedLMOutput | tuple[tf.Tensor]:
1078
+ r"""
1079
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1080
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1081
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1082
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1083
+
1084
+ Returns:
1085
+
1086
+ Example:
1087
+
1088
+ ```python
1089
+ >>> import tensorflow as tf
1090
+ >>> from transformers import AutoTokenizer, TFAlbertForMaskedLM
1091
+
1092
+ >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
1093
+ >>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
1094
+
1095
+ >>> # add mask_token
1096
+ >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="tf")
1097
+ >>> logits = model(**inputs).logits
1098
+
1099
+ >>> # retrieve index of [MASK]
1100
+ >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
1101
+ >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
1102
+ >>> tokenizer.decode(predicted_token_id)
1103
+ 'france'
1104
+ ```
1105
+
1106
+ ```python
1107
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
1108
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
1109
+ >>> outputs = model(**inputs, labels=labels)
1110
+ >>> round(float(outputs.loss), 2)
1111
+ 0.81
1112
+ ```
1113
+ """
1114
+ outputs = self.albert(
1115
+ input_ids=input_ids,
1116
+ attention_mask=attention_mask,
1117
+ token_type_ids=token_type_ids,
1118
+ position_ids=position_ids,
1119
+ head_mask=head_mask,
1120
+ inputs_embeds=inputs_embeds,
1121
+ output_attentions=output_attentions,
1122
+ output_hidden_states=output_hidden_states,
1123
+ return_dict=return_dict,
1124
+ training=training,
1125
+ )
1126
+ sequence_output = outputs[0]
1127
+ prediction_scores = self.predictions(hidden_states=sequence_output, training=training)
1128
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
1129
+
1130
+ if not return_dict:
1131
+ output = (prediction_scores,) + outputs[2:]
1132
+
1133
+ return ((loss,) + output) if loss is not None else output
1134
+
1135
+ return TFMaskedLMOutput(
1136
+ loss=loss,
1137
+ logits=prediction_scores,
1138
+ hidden_states=outputs.hidden_states,
1139
+ attentions=outputs.attentions,
1140
+ )
1141
+
1142
+ def build(self, input_shape=None):
1143
+ if self.built:
1144
+ return
1145
+ self.built = True
1146
+ if getattr(self, "albert", None) is not None:
1147
+ with tf.name_scope(self.albert.name):
1148
+ self.albert.build(None)
1149
+ if getattr(self, "predictions", None) is not None:
1150
+ with tf.name_scope(self.predictions.name):
1151
+ self.predictions.build(None)
1152
+
1153
+
1154
+ @add_start_docstrings(
1155
+ """
1156
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1157
+ output) e.g. for GLUE tasks.
1158
+ """,
1159
+ ALBERT_START_DOCSTRING,
1160
+ )
1161
+ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
1162
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1163
+ _keys_to_ignore_on_load_unexpected = [r"predictions"]
1164
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1165
+
1166
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1167
+ super().__init__(config, *inputs, **kwargs)
1168
+
1169
+ self.num_labels = config.num_labels
1170
+
1171
+ self.albert = TFAlbertMainLayer(config, name="albert")
1172
+ self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
1173
+ self.classifier = keras.layers.Dense(
1174
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1175
+ )
1176
+ self.config = config
1177
+
1178
+ @unpack_inputs
1179
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1180
+ @add_code_sample_docstrings(
1181
+ checkpoint="vumichien/albert-base-v2-imdb",
1182
+ output_type=TFSequenceClassifierOutput,
1183
+ config_class=_CONFIG_FOR_DOC,
1184
+ expected_output="'LABEL_1'",
1185
+ expected_loss=0.12,
1186
+ )
1187
+ def call(
1188
+ self,
1189
+ input_ids: TFModelInputType | None = None,
1190
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1191
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1192
+ position_ids: np.ndarray | tf.Tensor | None = None,
1193
+ head_mask: np.ndarray | tf.Tensor | None = None,
1194
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1195
+ output_attentions: bool | None = None,
1196
+ output_hidden_states: bool | None = None,
1197
+ return_dict: bool | None = None,
1198
+ labels: np.ndarray | tf.Tensor | None = None,
1199
+ training: bool | None = False,
1200
+ ) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
1201
+ r"""
1202
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1203
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1204
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
1205
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1206
+ """
1207
+ outputs = self.albert(
1208
+ input_ids=input_ids,
1209
+ attention_mask=attention_mask,
1210
+ token_type_ids=token_type_ids,
1211
+ position_ids=position_ids,
1212
+ head_mask=head_mask,
1213
+ inputs_embeds=inputs_embeds,
1214
+ output_attentions=output_attentions,
1215
+ output_hidden_states=output_hidden_states,
1216
+ return_dict=return_dict,
1217
+ training=training,
1218
+ )
1219
+ pooled_output = outputs[1]
1220
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
1221
+ logits = self.classifier(inputs=pooled_output)
1222
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
1223
+
1224
+ if not return_dict:
1225
+ output = (logits,) + outputs[2:]
1226
+
1227
+ return ((loss,) + output) if loss is not None else output
1228
+
1229
+ return TFSequenceClassifierOutput(
1230
+ loss=loss,
1231
+ logits=logits,
1232
+ hidden_states=outputs.hidden_states,
1233
+ attentions=outputs.attentions,
1234
+ )
1235
+
1236
+ def build(self, input_shape=None):
1237
+ if self.built:
1238
+ return
1239
+ self.built = True
1240
+ if getattr(self, "albert", None) is not None:
1241
+ with tf.name_scope(self.albert.name):
1242
+ self.albert.build(None)
1243
+ if getattr(self, "classifier", None) is not None:
1244
+ with tf.name_scope(self.classifier.name):
1245
+ self.classifier.build([None, None, self.config.hidden_size])
1246
+
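+ # Usage sketch for the sequence-classification head above, assuming the public
+ # "albert/albert-base-v2" checkpoint; `num_labels=2` is an arbitrary choice for the example.
+ #
+ #   import tensorflow as tf
+ #   from transformers import AutoTokenizer, TFAlbertForSequenceClassification
+ #
+ #   tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ #   model = TFAlbertForSequenceClassification.from_pretrained("albert/albert-base-v2", num_labels=2)
+ #   inputs = tokenizer("A surprisingly delightful film.", return_tensors="tf")
+ #   logits = model(**inputs).logits                             # shape: (1, 2)
+ #   predicted_class = int(tf.math.argmax(logits, axis=-1)[0])   # index into model.config.id2label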
1247
+
1248
+ @add_start_docstrings(
1249
+ """
1250
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1251
+ Named-Entity-Recognition (NER) tasks.
1252
+ """,
1253
+ ALBERT_START_DOCSTRING,
1254
+ )
1255
+ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
1256
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1257
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
1258
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1259
+
1260
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1261
+ super().__init__(config, *inputs, **kwargs)
1262
+
1263
+ self.num_labels = config.num_labels
1264
+
1265
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
1266
+ classifier_dropout_prob = (
1267
+ config.classifier_dropout_prob
1268
+ if config.classifier_dropout_prob is not None
1269
+ else config.hidden_dropout_prob
1270
+ )
1271
+ self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob)
1272
+ self.classifier = keras.layers.Dense(
1273
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1274
+ )
1275
+ self.config = config
1276
+
1277
+ @unpack_inputs
1278
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1279
+ @add_code_sample_docstrings(
1280
+ checkpoint=_CHECKPOINT_FOR_DOC,
1281
+ output_type=TFTokenClassifierOutput,
1282
+ config_class=_CONFIG_FOR_DOC,
1283
+ )
1284
+ def call(
1285
+ self,
1286
+ input_ids: TFModelInputType | None = None,
1287
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1288
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1289
+ position_ids: np.ndarray | tf.Tensor | None = None,
1290
+ head_mask: np.ndarray | tf.Tensor | None = None,
1291
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1292
+ output_attentions: bool | None = None,
1293
+ output_hidden_states: bool | None = None,
1294
+ return_dict: bool | None = None,
1295
+ labels: np.ndarray | tf.Tensor | None = None,
1296
+ training: bool | None = False,
1297
+ ) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
1298
+ r"""
1299
+ labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1300
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1301
+ """
1302
+ outputs = self.albert(
1303
+ input_ids=input_ids,
1304
+ attention_mask=attention_mask,
1305
+ token_type_ids=token_type_ids,
1306
+ position_ids=position_ids,
1307
+ head_mask=head_mask,
1308
+ inputs_embeds=inputs_embeds,
1309
+ output_attentions=output_attentions,
1310
+ output_hidden_states=output_hidden_states,
1311
+ return_dict=return_dict,
1312
+ training=training,
1313
+ )
1314
+ sequence_output = outputs[0]
1315
+ sequence_output = self.dropout(inputs=sequence_output, training=training)
1316
+ logits = self.classifier(inputs=sequence_output)
1317
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
1318
+
1319
+ if not return_dict:
1320
+ output = (logits,) + outputs[2:]
1321
+
1322
+ return ((loss,) + output) if loss is not None else output
1323
+
1324
+ return TFTokenClassifierOutput(
1325
+ loss=loss,
1326
+ logits=logits,
1327
+ hidden_states=outputs.hidden_states,
1328
+ attentions=outputs.attentions,
1329
+ )
1330
+
1331
+ def build(self, input_shape=None):
1332
+ if self.built:
1333
+ return
1334
+ self.built = True
1335
+ if getattr(self, "albert", None) is not None:
1336
+ with tf.name_scope(self.albert.name):
1337
+ self.albert.build(None)
1338
+ if getattr(self, "classifier", None) is not None:
1339
+ with tf.name_scope(self.classifier.name):
1340
+ self.classifier.build([None, None, self.config.hidden_size])
1341
+
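+ # Usage sketch: the token-classification head above emits one logit vector per token, so
+ # `logits` has shape (batch_size, sequence_length, num_labels). The checkpoint name and
+ # `num_labels=9` (a CoNLL-style NER label set) are assumptions for the example.
+ #
+ #   import tensorflow as tf
+ #   from transformers import AutoTokenizer, TFAlbertForTokenClassification
+ #
+ #   tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ #   model = TFAlbertForTokenClassification.from_pretrained("albert/albert-base-v2", num_labels=9)
+ #   inputs = tokenizer("Jim lives in Paris", return_tensors="tf")
+ #   per_token_ids = tf.math.argmax(model(**inputs).logits, axis=-1)  # (1, sequence_length)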
1342
+
1343
+ @add_start_docstrings(
1344
+ """
1345
+ Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1346
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1347
+ """,
1348
+ ALBERT_START_DOCSTRING,
1349
+ )
1350
+ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
1351
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1352
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
1353
+
1354
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1355
+ super().__init__(config, *inputs, **kwargs)
1356
+
1357
+ self.num_labels = config.num_labels
1358
+
1359
+ self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
1360
+ self.qa_outputs = keras.layers.Dense(
1361
+ units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
1362
+ )
1363
+ self.config = config
1364
+
1365
+ @unpack_inputs
1366
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1367
+ @add_code_sample_docstrings(
1368
+ checkpoint="vumichien/albert-base-v2-squad2",
1369
+ output_type=TFQuestionAnsweringModelOutput,
1370
+ config_class=_CONFIG_FOR_DOC,
1371
+ qa_target_start_index=12,
1372
+ qa_target_end_index=13,
1373
+ expected_output="'a nice puppet'",
1374
+ expected_loss=7.36,
1375
+ )
1376
+ def call(
1377
+ self,
1378
+ input_ids: TFModelInputType | None = None,
1379
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1380
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1381
+ position_ids: np.ndarray | tf.Tensor | None = None,
1382
+ head_mask: np.ndarray | tf.Tensor | None = None,
1383
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1384
+ output_attentions: bool | None = None,
1385
+ output_hidden_states: bool | None = None,
1386
+ return_dict: bool | None = None,
1387
+ start_positions: np.ndarray | tf.Tensor | None = None,
1388
+ end_positions: np.ndarray | tf.Tensor | None = None,
1389
+ training: bool | None = False,
1390
+ ) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
1391
+ r"""
1392
+ start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1393
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1394
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1395
+ are not taken into account for computing the loss.
1396
+ end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1397
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1398
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1399
+ are not taken into account for computing the loss.
1400
+ """
1401
+ outputs = self.albert(
1402
+ input_ids=input_ids,
1403
+ attention_mask=attention_mask,
1404
+ token_type_ids=token_type_ids,
1405
+ position_ids=position_ids,
1406
+ head_mask=head_mask,
1407
+ inputs_embeds=inputs_embeds,
1408
+ output_attentions=output_attentions,
1409
+ output_hidden_states=output_hidden_states,
1410
+ return_dict=return_dict,
1411
+ training=training,
1412
+ )
1413
+ sequence_output = outputs[0]
1414
+ logits = self.qa_outputs(inputs=sequence_output)
1415
+ start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
1416
+ start_logits = tf.squeeze(input=start_logits, axis=-1)
1417
+ end_logits = tf.squeeze(input=end_logits, axis=-1)
1418
+ loss = None
1419
+
1420
+ if start_positions is not None and end_positions is not None:
1421
+ labels = {"start_position": start_positions}
1422
+ labels["end_position"] = end_positions
1423
+ loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
1424
+
1425
+ if not return_dict:
1426
+ output = (start_logits, end_logits) + outputs[2:]
1427
+
1428
+ return ((loss,) + output) if loss is not None else output
1429
+
1430
+ return TFQuestionAnsweringModelOutput(
1431
+ loss=loss,
1432
+ start_logits=start_logits,
1433
+ end_logits=end_logits,
1434
+ hidden_states=outputs.hidden_states,
1435
+ attentions=outputs.attentions,
1436
+ )
1437
+
1438
+ def build(self, input_shape=None):
1439
+ if self.built:
1440
+ return
1441
+ self.built = True
1442
+ if getattr(self, "albert", None) is not None:
1443
+ with tf.name_scope(self.albert.name):
1444
+ self.albert.build(None)
1445
+ if getattr(self, "qa_outputs", None) is not None:
1446
+ with tf.name_scope(self.qa_outputs.name):
1447
+ self.qa_outputs.build([None, None, self.config.hidden_size])
1448
+
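+ # Usage sketch: the answer span is recovered from the argmax of the start and end logits
+ # produced by the class above. The checkpoint is the one referenced in its docstring; the
+ # question/context pair is an illustrative assumption.
+ #
+ #   import tensorflow as tf
+ #   from transformers import AutoTokenizer, TFAlbertForQuestionAnswering
+ #
+ #   tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-squad2")
+ #   model = TFAlbertForQuestionAnswering.from_pretrained("vumichien/albert-base-v2-squad2")
+ #   inputs = tokenizer("Who wrote the play?", "The play was written by Shakespeare.", return_tensors="tf")
+ #   outputs = model(**inputs)
+ #   start = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
+ #   end = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
+ #   answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])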
1449
+
1450
+ @add_start_docstrings(
1451
+ """
1452
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1453
+ softmax) e.g. for RocStories/SWAG tasks.
1454
+ """,
1455
+ ALBERT_START_DOCSTRING,
1456
+ )
1457
+ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
1458
+ # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
1459
+ _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
1460
+ _keys_to_ignore_on_load_missing = [r"dropout"]
1461
+
1462
+ def __init__(self, config: AlbertConfig, *inputs, **kwargs):
1463
+ super().__init__(config, *inputs, **kwargs)
1464
+
1465
+ self.albert = TFAlbertMainLayer(config, name="albert")
1466
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
1467
+ self.classifier = keras.layers.Dense(
1468
+ units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1469
+ )
1470
+ self.config = config
1471
+
1472
+ @unpack_inputs
1473
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1474
+ @add_code_sample_docstrings(
1475
+ checkpoint=_CHECKPOINT_FOR_DOC,
1476
+ output_type=TFMultipleChoiceModelOutput,
1477
+ config_class=_CONFIG_FOR_DOC,
1478
+ )
1479
+ def call(
1480
+ self,
1481
+ input_ids: TFModelInputType | None = None,
1482
+ attention_mask: np.ndarray | tf.Tensor | None = None,
1483
+ token_type_ids: np.ndarray | tf.Tensor | None = None,
1484
+ position_ids: np.ndarray | tf.Tensor | None = None,
1485
+ head_mask: np.ndarray | tf.Tensor | None = None,
1486
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
1487
+ output_attentions: bool | None = None,
1488
+ output_hidden_states: bool | None = None,
1489
+ return_dict: bool | None = None,
1490
+ labels: np.ndarray | tf.Tensor | None = None,
1491
+ training: bool | None = False,
1492
+ ) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]:
1493
+ r"""
1494
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
1495
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices - 1]`
1496
+ where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
1497
+ """
1498
+
1499
+ if input_ids is not None:
1500
+ num_choices = shape_list(input_ids)[1]
1501
+ seq_length = shape_list(input_ids)[2]
1502
+ else:
1503
+ num_choices = shape_list(inputs_embeds)[1]
1504
+ seq_length = shape_list(inputs_embeds)[2]
1505
+
1506
+ flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
1507
+ flat_attention_mask = (
1508
+ tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
1509
+ )
1510
+ flat_token_type_ids = (
1511
+ tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
1512
+ )
1513
+ flat_position_ids = (
1514
+ tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
1515
+ )
1516
+ flat_inputs_embeds = (
1517
+ tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
1518
+ if inputs_embeds is not None
1519
+ else None
1520
+ )
1521
+ outputs = self.albert(
1522
+ input_ids=flat_input_ids,
1523
+ attention_mask=flat_attention_mask,
1524
+ token_type_ids=flat_token_type_ids,
1525
+ position_ids=flat_position_ids,
1526
+ head_mask=head_mask,
1527
+ inputs_embeds=flat_inputs_embeds,
1528
+ output_attentions=output_attentions,
1529
+ output_hidden_states=output_hidden_states,
1530
+ return_dict=return_dict,
1531
+ training=training,
1532
+ )
1533
+ pooled_output = outputs[1]
1534
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
1535
+ logits = self.classifier(inputs=pooled_output)
1536
+ reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
1537
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
1538
+
1539
+ if not return_dict:
1540
+ output = (reshaped_logits,) + outputs[2:]
1541
+ return ((loss,) + output) if loss is not None else output
1542
+
1543
+ return TFMultipleChoiceModelOutput(
1544
+ loss=loss,
1545
+ logits=reshaped_logits,
1546
+ hidden_states=outputs.hidden_states,
1547
+ attentions=outputs.attentions,
1548
+ )
1549
+
1550
+ def build(self, input_shape=None):
1551
+ if self.built:
1552
+ return
1553
+ self.built = True
1554
+ if getattr(self, "albert", None) is not None:
1555
+ with tf.name_scope(self.albert.name):
1556
+ self.albert.build(None)
1557
+ if getattr(self, "classifier", None) is not None:
1558
+ with tf.name_scope(self.classifier.name):
1559
+ self.classifier.build([None, None, self.config.hidden_size])
1560
+
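+ # Usage sketch: inputs to the class above are shaped (batch_size, num_choices, sequence_length);
+ # the call flattens them to (batch_size * num_choices, sequence_length) and reshapes the
+ # single-unit classifier output back to (batch_size, num_choices). The checkpoint and prompt
+ # text are assumptions for the example.
+ #
+ #   import tensorflow as tf
+ #   from transformers import AutoTokenizer, TFAlbertForMultipleChoice
+ #
+ #   tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
+ #   model = TFAlbertForMultipleChoice.from_pretrained("albert/albert-base-v2")
+ #   prompt = "The cat sat on"
+ #   choices = ["the mat.", "the moon."]
+ #   enc = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)
+ #   inputs = {k: tf.expand_dims(v, 0) for k, v in enc.items()}  # add the num_choices axis
+ #   best_choice = int(tf.math.argmax(model(**inputs).logits, axis=-1)[0])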
1561
+
1562
+ __all__ = [
1563
+ "TFAlbertPreTrainedModel",
1564
+ "TFAlbertModel",
1565
+ "TFAlbertForPreTraining",
1566
+ "TFAlbertForMaskedLM",
1567
+ "TFAlbertForSequenceClassification",
1568
+ "TFAlbertForTokenClassification",
1569
+ "TFAlbertForQuestionAnswering",
1570
+ "TFAlbertForMultipleChoice",
1571
+ "TFAlbertMainLayer",
1572
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py ADDED
@@ -0,0 +1,320 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for ALBERT model."""
16
+
17
+ import os
18
+ import unicodedata
19
+ from shutil import copyfile
20
+ from typing import Any, Optional
21
+
22
+ import sentencepiece as spm
23
+
24
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
25
+ from ...utils import logging
26
+ from ...utils.import_utils import requires
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
31
+
32
+
33
+ SPIECE_UNDERLINE = "▁"
34
+
35
+
36
+ @requires(backends=("sentencepiece",))
37
+ class AlbertTokenizer(PreTrainedTokenizer):
38
+ """
39
+ Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
40
+
41
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
42
+ this superclass for more information regarding those methods.
43
+
44
+ Args:
45
+ vocab_file (`str`):
46
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
47
+ contains the vocabulary necessary to instantiate a tokenizer.
48
+ do_lower_case (`bool`, *optional*, defaults to `True`):
49
+ Whether or not to lowercase the input when tokenizing.
50
+ remove_space (`bool`, *optional*, defaults to `True`):
51
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
52
+ keep_accents (`bool`, *optional*, defaults to `False`):
53
+ Whether or not to keep accents when tokenizing.
54
+ bos_token (`str`, *optional*, defaults to `"[CLS]"`):
55
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
56
+
57
+ <Tip>
58
+
59
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
60
+ sequence. The token used is the `cls_token`.
61
+
62
+ </Tip>
63
+
64
+ eos_token (`str`, *optional*, defaults to `"[SEP]"`):
65
+ The end of sequence token.
66
+
67
+ <Tip>
68
+
69
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
70
+ The token used is the `sep_token`.
71
+
72
+ </Tip>
73
+
74
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
75
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
76
+ token instead.
77
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
78
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
79
+ sequence classification or for a text and a question for question answering. It is also used as the last
80
+ token of a sequence built with special tokens.
81
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
82
+ The token used for padding, for example when batching sequences of different lengths.
83
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
84
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
85
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
86
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
87
+ The token used for masking values. This is the token used when training this model with masked language
88
+ modeling. This is the token which the model will try to predict.
89
+ sp_model_kwargs (`dict`, *optional*):
90
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
91
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
92
+ to set:
93
+
94
+ - `enable_sampling`: Enable subword regularization.
95
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
96
+
97
+ - `nbest_size = {0,1}`: No sampling is performed.
98
+ - `nbest_size > 1`: samples from the nbest_size results.
99
+ - `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
100
+ using the forward-filtering-and-backward-sampling algorithm.
101
+
102
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
103
+ BPE-dropout.
104
+
105
+ Attributes:
106
+ sp_model (`SentencePieceProcessor`):
107
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
108
+ """
109
+
110
+ vocab_files_names = VOCAB_FILES_NAMES
111
+
112
+ def __init__(
113
+ self,
114
+ vocab_file,
115
+ do_lower_case=True,
116
+ remove_space=True,
117
+ keep_accents=False,
118
+ bos_token="[CLS]",
119
+ eos_token="[SEP]",
120
+ unk_token="<unk>",
121
+ sep_token="[SEP]",
122
+ pad_token="<pad>",
123
+ cls_token="[CLS]",
124
+ mask_token="[MASK]",
125
+ sp_model_kwargs: Optional[dict[str, Any]] = None,
126
+ **kwargs,
127
+ ) -> None:
128
+ # The mask token behaves like a normal word, i.e. it includes the space before it and
129
+ # is included in the raw text; there should be a match in a non-normalized sentence.
130
+ mask_token = (
131
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
132
+ if isinstance(mask_token, str)
133
+ else mask_token
134
+ )
135
+
136
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
137
+
138
+ self.do_lower_case = do_lower_case
139
+ self.remove_space = remove_space
140
+ self.keep_accents = keep_accents
141
+ self.vocab_file = vocab_file
142
+
143
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
144
+ self.sp_model.Load(vocab_file)
145
+
146
+ super().__init__(
147
+ do_lower_case=do_lower_case,
148
+ remove_space=remove_space,
149
+ keep_accents=keep_accents,
150
+ bos_token=bos_token,
151
+ eos_token=eos_token,
152
+ unk_token=unk_token,
153
+ sep_token=sep_token,
154
+ pad_token=pad_token,
155
+ cls_token=cls_token,
156
+ mask_token=mask_token,
157
+ sp_model_kwargs=self.sp_model_kwargs,
158
+ **kwargs,
159
+ )
160
+
161
+ @property
162
+ def vocab_size(self) -> int:
163
+ return len(self.sp_model)
164
+
165
+ def get_vocab(self) -> dict[str, int]:
166
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
167
+ vocab.update(self.added_tokens_encoder)
168
+ return vocab
169
+
170
+ def __getstate__(self):
171
+ state = self.__dict__.copy()
172
+ state["sp_model"] = None
173
+ return state
174
+
175
+ def __setstate__(self, d):
176
+ self.__dict__ = d
177
+
178
+ # for backward compatibility
179
+ if not hasattr(self, "sp_model_kwargs"):
180
+ self.sp_model_kwargs = {}
181
+
182
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
183
+ self.sp_model.Load(self.vocab_file)
184
+
185
+ def preprocess_text(self, inputs):
186
+ if self.remove_space:
187
+ outputs = " ".join(inputs.strip().split())
188
+ else:
189
+ outputs = inputs
190
+ outputs = outputs.replace("``", '"').replace("''", '"')
191
+
192
+ if not self.keep_accents:
193
+ outputs = unicodedata.normalize("NFKD", outputs)
194
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
195
+ if self.do_lower_case:
196
+ outputs = outputs.lower()
197
+
198
+ return outputs
199
+
200
+ def _tokenize(self, text: str) -> list[str]:
201
+ """Tokenize a string."""
202
+ text = self.preprocess_text(text)
203
+ pieces = self.sp_model.encode(text, out_type=str)
204
+ new_pieces = []
205
+ for piece in pieces:
206
+ if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
207
+ # Logic to handle special cases; see https://github.com/google-research/bert/blob/master/README.md#tokenization
208
+ # `9,9` -> ['▁9', ',', '9'] instead of [`_9,`, '9']
209
+ cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
210
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
211
+ if len(cur_pieces[0]) == 1:
212
+ cur_pieces = cur_pieces[1:]
213
+ else:
214
+ cur_pieces[0] = cur_pieces[0][1:]
215
+ cur_pieces.append(piece[-1])
216
+ new_pieces.extend(cur_pieces)
217
+ else:
218
+ new_pieces.append(piece)
219
+
220
+ return new_pieces
221
+
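+ # Worked example of the digit/comma special case handled above (the exact pieces depend on
+ # the SentencePiece vocabulary, so this is a sketch rather than a guaranteed output):
+ #
+ #   tok = AlbertTokenizer.from_pretrained("albert/albert-base-v2")
+ #   tok.tokenize("9,9")  # -> ['▁9', ',', '9'] rather than ['▁9,', '9']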
222
+ def _convert_token_to_id(self, token):
223
+ """Converts a token (str) in an id using the vocab."""
224
+ return self.sp_model.PieceToId(token)
225
+
226
+ def _convert_id_to_token(self, index):
227
+ """Converts an index (integer) in a token (str) using the vocab."""
228
+ return self.sp_model.IdToPiece(index)
229
+
230
+ def convert_tokens_to_string(self, tokens):
231
+ """Converts a sequence of tokens (string) in a single string."""
232
+ current_sub_tokens = []
233
+ out_string = ""
234
+ prev_is_special = False
235
+ for token in tokens:
236
+ # make sure that special tokens are not decoded using sentencepiece model
237
+ if token in self.all_special_tokens:
238
+ if not prev_is_special:
239
+ out_string += " "
240
+ out_string += self.sp_model.decode(current_sub_tokens) + token
241
+ prev_is_special = True
242
+ current_sub_tokens = []
243
+ else:
244
+ current_sub_tokens.append(token)
245
+ prev_is_special = False
246
+ out_string += self.sp_model.decode(current_sub_tokens)
247
+ return out_string.strip()
248
+
249
+ def build_inputs_with_special_tokens(
250
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
251
+ ) -> list[int]:
252
+ """
253
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
254
+ adding special tokens. An ALBERT sequence has the following format:
255
+
256
+ - single sequence: `[CLS] X [SEP]`
257
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
258
+
259
+ Args:
260
+ token_ids_0 (`List[int]`):
261
+ List of IDs to which the special tokens will be added.
262
+ token_ids_1 (`List[int]`, *optional*):
263
+ Optional second list of IDs for sequence pairs.
264
+
265
+ Returns:
266
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
267
+ """
268
+ sep = [self.sep_token_id]
269
+ cls = [self.cls_token_id]
270
+ if token_ids_1 is None:
271
+ return cls + token_ids_0 + sep
272
+ return cls + token_ids_0 + sep + token_ids_1 + sep
273
+
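+ # Sketch of the layout this produces (the concrete cls/sep ids depend on the vocabulary):
+ #
+ #   tok = AlbertTokenizer.from_pretrained("albert/albert-base-v2")
+ #   tok.build_inputs_with_special_tokens([10, 11], [20, 21])
+ #   # -> [cls_id, 10, 11, sep_id, 20, 21, sep_id], i.e. [CLS] A [SEP] B [SEP]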
274
+ def get_special_tokens_mask(
275
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
276
+ ) -> list[int]:
277
+ """
278
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
279
+ special tokens using the tokenizer `prepare_for_model` method.
280
+
281
+ Args:
282
+ token_ids_0 (`List[int]`):
283
+ List of IDs.
284
+ token_ids_1 (`List[int]`, *optional*):
285
+ Optional second list of IDs for sequence pairs.
286
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
287
+ Whether or not the token list is already formatted with special tokens for the model.
288
+
289
+ Returns:
290
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
291
+ """
292
+
293
+ if already_has_special_tokens:
294
+ return super().get_special_tokens_mask(
295
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
296
+ )
297
+
298
+ if token_ids_1 is not None:
299
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
300
+ return [1] + ([0] * len(token_ids_0)) + [1]
301
+
302
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
303
+ if not os.path.isdir(save_directory):
304
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
305
+ return
306
+ out_vocab_file = os.path.join(
307
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
308
+ )
309
+
310
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
311
+ copyfile(self.vocab_file, out_vocab_file)
312
+ elif not os.path.isfile(self.vocab_file):
313
+ with open(out_vocab_file, "wb") as fi:
314
+ content_spiece_model = self.sp_model.serialized_model_proto()
315
+ fi.write(content_spiece_model)
316
+
317
+ return (out_vocab_file,)
318
+
319
+
320
+ __all__ = ["AlbertTokenizer"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py ADDED
@@ -0,0 +1,178 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for ALBERT model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Optional
20
+
21
+ from ...tokenization_utils import AddedToken
22
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
23
+ from ...utils import is_sentencepiece_available, logging
24
+
25
+
26
+ if is_sentencepiece_available():
27
+ from .tokenization_albert import AlbertTokenizer
28
+ else:
29
+ AlbertTokenizer = None
30
+
31
+ logger = logging.get_logger(__name__)
32
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
33
+
34
+
35
+ SPIECE_UNDERLINE = "▁"
36
+
37
+
38
+ class AlbertTokenizerFast(PreTrainedTokenizerFast):
39
+ """
40
+ Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on
41
+ [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This
42
+ tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to
43
+ this superclass for more information regarding those methods.
44
+
45
+ Args:
46
+ vocab_file (`str`):
47
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
48
+ contains the vocabulary necessary to instantiate a tokenizer.
49
+ do_lower_case (`bool`, *optional*, defaults to `True`):
50
+ Whether or not to lowercase the input when tokenizing.
51
+ remove_space (`bool`, *optional*, defaults to `True`):
52
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
53
+ keep_accents (`bool`, *optional*, defaults to `False`):
54
+ Whether or not to keep accents when tokenizing.
55
+ bos_token (`str`, *optional*, defaults to `"[CLS]"`):
56
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
57
+
58
+ <Tip>
59
+
60
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
61
+ sequence. The token used is the `cls_token`.
62
+
63
+ </Tip>
64
+
65
+ eos_token (`str`, *optional*, defaults to `"[SEP]"`):
66
+ The end of sequence token. When building a sequence using special tokens, this is not the token
67
+ that is used for the end of sequence. The token used is the `sep_token`.
68
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
69
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
70
+ token instead.
71
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
72
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
73
+ sequence classification or for a text and a question for question answering. It is also used as the last
74
+ token of a sequence built with special tokens.
75
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
76
+ The token used for padding, for example when batching sequences of different lengths.
77
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
78
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
79
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
80
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
81
+ The token used for masking values. This is the token used when training this model with masked language
82
+ modeling. This is the token which the model will try to predict.
83
+ """
84
+
85
+ vocab_files_names = VOCAB_FILES_NAMES
86
+ slow_tokenizer_class = AlbertTokenizer
87
+
88
+ def __init__(
89
+ self,
90
+ vocab_file=None,
91
+ tokenizer_file=None,
92
+ do_lower_case=True,
93
+ remove_space=True,
94
+ keep_accents=False,
95
+ bos_token="[CLS]",
96
+ eos_token="[SEP]",
97
+ unk_token="<unk>",
98
+ sep_token="[SEP]",
99
+ pad_token="<pad>",
100
+ cls_token="[CLS]",
101
+ mask_token="[MASK]",
102
+ **kwargs,
103
+ ):
104
+ # The mask token behaves like a normal word, i.e. it includes the space before it and
105
+ # is included in the raw text; there should be a match in a non-normalized sentence.
106
+ mask_token = (
107
+ AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
108
+ if isinstance(mask_token, str)
109
+ else mask_token
110
+ )
111
+
112
+ super().__init__(
113
+ vocab_file,
114
+ tokenizer_file=tokenizer_file,
115
+ do_lower_case=do_lower_case,
116
+ remove_space=remove_space,
117
+ keep_accents=keep_accents,
118
+ bos_token=bos_token,
119
+ eos_token=eos_token,
120
+ unk_token=unk_token,
121
+ sep_token=sep_token,
122
+ pad_token=pad_token,
123
+ cls_token=cls_token,
124
+ mask_token=mask_token,
125
+ **kwargs,
126
+ )
127
+
128
+ self.do_lower_case = do_lower_case
129
+ self.remove_space = remove_space
130
+ self.keep_accents = keep_accents
131
+ self.vocab_file = vocab_file
132
+
133
+ def build_inputs_with_special_tokens(
134
+ self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
135
+ ) -> list[int]:
136
+ """
137
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
138
+ adding special tokens. An ALBERT sequence has the following format:
139
+
140
+ - single sequence: `[CLS] X [SEP]`
141
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
142
+
143
+ Args:
144
+ token_ids_0 (`List[int]`):
145
+ List of IDs to which the special tokens will be added
146
+ token_ids_1 (`List[int]`, *optional*):
147
+ Optional second list of IDs for sequence pairs.
148
+
149
+ Returns:
150
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
151
+ """
152
+ sep = [self.sep_token_id]
153
+ cls = [self.cls_token_id]
154
+ if token_ids_1 is None:
155
+ return cls + token_ids_0 + sep
156
+ return cls + token_ids_0 + sep + token_ids_1 + sep
157
+
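+ # The fast tokenizer yields the same layout as the slow one; a minimal sketch:
+ #
+ #   tok = AlbertTokenizerFast.from_pretrained("albert/albert-base-v2")
+ #   tok.build_inputs_with_special_tokens([10, 11])  # -> [cls_id, 10, 11, sep_id]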
158
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
159
+ if not self.can_save_slow_tokenizer:
160
+ raise ValueError(
161
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
162
+ "tokenizer."
163
+ )
164
+
165
+ if not os.path.isdir(save_directory):
166
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
167
+ return
168
+ out_vocab_file = os.path.join(
169
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
170
+ )
171
+
172
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
173
+ copyfile(self.vocab_file, out_vocab_file)
174
+
175
+ return (out_vocab_file,)
176
+
177
+
178
+ __all__ = ["AlbertTokenizerFast"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .auto_factory import *
22
+ from .configuration_auto import *
23
+ from .feature_extraction_auto import *
24
+ from .image_processing_auto import *
25
+ from .modeling_auto import *
26
+ from .modeling_flax_auto import *
27
+ from .modeling_tf_auto import *
28
+ from .processing_auto import *
29
+ from .tokenization_auto import *
30
+ from .video_processing_auto import *
31
+ else:
32
+ import sys
33
+
34
+ _file = globals()["__file__"]
35
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
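+ # Behavioural sketch (under the standard public API): replacing the module object with
+ # `_LazyModule` defers the heavy submodule imports until an attribute is first accessed, e.g.:
+ #
+ #   from transformers.models import auto
+ #   auto.AutoConfig  # only now is configuration_auto actually imported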
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py ADDED
@@ -0,0 +1,882 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Factory function to build auto-model classes."""
16
+
17
+ import copy
18
+ import importlib
19
+ import json
20
+ import os
21
+ import warnings
22
+ from collections import OrderedDict
23
+ from collections.abc import Iterator
24
+ from typing import Any, TypeVar, Union
25
+
26
+ from ...configuration_utils import PretrainedConfig
27
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
28
+ from ...utils import (
29
+ CONFIG_NAME,
30
+ cached_file,
31
+ copy_func,
32
+ extract_commit_hash,
33
+ find_adapter_config_file,
34
+ is_peft_available,
35
+ is_torch_available,
36
+ logging,
37
+ requires_backends,
38
+ )
39
+ from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
40
+
41
+
42
+ if is_torch_available():
43
+ from ...generation import GenerationMixin
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _T = TypeVar("_T")
49
+ # Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol
50
+ _LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]
51
+
52
+ CLASS_DOCSTRING = """
53
+ This is a generic model class that will be instantiated as one of the model classes of the library when created
54
+ with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
55
+ method.
56
+
57
+ This class cannot be instantiated directly using `__init__()` (throws an error).
58
+ """
59
+
60
+ FROM_CONFIG_DOCSTRING = """
61
+ Instantiates one of the model classes of the library from a configuration.
62
+
63
+ Note:
64
+ Loading a model from its configuration file does **not** load the model weights. It only affects the
65
+ model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
66
+
67
+ Args:
68
+ config ([`PretrainedConfig`]):
69
+ The model class to instantiate is selected based on the configuration class:
70
+
71
+ List options
72
+ attn_implementation (`str`, *optional*):
73
+ The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
74
+
75
+ Examples:
76
+
77
+ ```python
78
+ >>> from transformers import AutoConfig, BaseAutoModelClass
79
+
80
+ >>> # Download configuration from huggingface.co and cache.
81
+ >>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
82
+ >>> model = BaseAutoModelClass.from_config(config)
83
+ ```
84
+ """
85
+
86
+ FROM_PRETRAINED_TORCH_DOCSTRING = """
87
+ Instantiate one of the model classes of the library from a pretrained model.
88
+
89
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
90
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
91
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
92
+
93
+ List options
94
+
95
+ The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
96
+ deactivated). To train the model, you should first set it back in training mode with `model.train()`.
97
+
98
+ Args:
99
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
100
+ Can be either:
101
+
102
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
103
+ - A path to a *directory* containing model weights saved using
104
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
105
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
106
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
107
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
108
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
109
+ model_args (additional positional arguments, *optional*):
110
+ Will be passed along to the underlying model `__init__()` method.
111
+ config ([`PretrainedConfig`], *optional*):
112
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
113
+ be automatically loaded when:
114
+
115
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
116
+ model).
117
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
118
+ save directory.
119
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
120
+ configuration JSON file named *config.json* is found in the directory.
121
+ state_dict (*dict[str, torch.Tensor]*, *optional*):
122
+ A state dictionary to use instead of a state dictionary loaded from saved weights file.
123
+
124
+ This option can be used if you want to create a model from a pretrained configuration but load your own
125
+ weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
126
+ [`~PreTrainedModel.from_pretrained`] is not a simpler option.
127
+ cache_dir (`str` or `os.PathLike`, *optional*):
128
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
129
+ standard cache should not be used.
130
+ from_tf (`bool`, *optional*, defaults to `False`):
131
+ Load the model weights from a TensorFlow checkpoint save file (see docstring of
132
+ `pretrained_model_name_or_path` argument).
133
+ force_download (`bool`, *optional*, defaults to `False`):
134
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
135
+ cached versions if they exist.
136
+ resume_download:
137
+ Deprecated and ignored. All downloads are now resumed by default when possible.
138
+ Will be removed in v5 of Transformers.
139
+ proxies (`dict[str, str]`, *optional*):
140
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
141
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
142
+ output_loading_info(`bool`, *optional*, defaults to `False`):
143
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
144
+ local_files_only(`bool`, *optional*, defaults to `False`):
145
+ Whether or not to only look at local files (e.g., not try downloading the model).
146
+ revision (`str`, *optional*, defaults to `"main"`):
147
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
148
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
149
+ identifier allowed by git.
150
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
151
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
152
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
153
+ execute code present on the Hub on your local machine.
154
+ code_revision (`str`, *optional*, defaults to `"main"`):
155
+ The specific revision to use for the code on the Hub, if the code lives in a different repository than
156
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
157
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
158
+ allowed by git.
159
+ kwargs (additional keyword arguments, *optional*):
160
+ Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
161
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
162
+ automatically loaded:
163
+
164
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
165
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
166
+ already been done)
167
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
168
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
169
+ corresponds to a configuration attribute will be used to override said attribute with the
170
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
171
+ will be passed to the underlying model's `__init__` function.
172
+
173
+ Examples:
174
+
175
+ ```python
176
+ >>> from transformers import AutoConfig, BaseAutoModelClass
177
+
178
+ >>> # Download model and configuration from huggingface.co and cache.
179
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
180
+
181
+ >>> # Update configuration during loading
182
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
183
+ >>> model.config.output_attentions
184
+ True
185
+
186
+ >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
187
+ >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
188
+ >>> model = BaseAutoModelClass.from_pretrained(
189
+ ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
190
+ ... )
191
+ ```
192
+ """
193
+
194
+ FROM_PRETRAINED_TF_DOCSTRING = """
195
+ Instantiate one of the model classes of the library from a pretrained model.
196
+
197
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
198
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
199
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
200
+
201
+ List options
202
+
203
+ Args:
204
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
205
+ Can be either:
206
+
207
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
208
+ - A path to a *directory* containing model weights saved using
209
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
210
+ - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
211
+ case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
212
+ argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
213
+ using the provided conversion scripts and loading the TensorFlow model afterwards.
214
+ model_args (additional positional arguments, *optional*):
215
+ Will be passed along to the underlying model `__init__()` method.
216
+ config ([`PretrainedConfig`], *optional*):
217
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
218
+ be automatically loaded when:
219
+
220
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
221
+ model).
222
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
223
+ save directory.
224
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
225
+ configuration JSON file named *config.json* is found in the directory.
226
+ cache_dir (`str` or `os.PathLike`, *optional*):
227
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
228
+ standard cache should not be used.
229
+ from_pt (`bool`, *optional*, defaults to `False`):
230
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
231
+ `pretrained_model_name_or_path` argument).
232
+ force_download (`bool`, *optional*, defaults to `False`):
233
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
234
+ cached versions if they exist.
235
+ resume_download:
236
+ Deprecated and ignored. All downloads are now resumed by default when possible.
237
+ Will be removed in v5 of Transformers.
238
+ proxies (`dict[str, str]`, *optional*):
239
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
240
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
241
+         output_loading_info (`bool`, *optional*, defaults to `False`):
242
+             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
243
+         local_files_only (`bool`, *optional*, defaults to `False`):
244
+             Whether or not to only look at local files (e.g., not try downloading the model).
245
+ revision (`str`, *optional*, defaults to `"main"`):
246
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
247
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
248
+ identifier allowed by git.
249
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
250
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
251
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
252
+ execute code present on the Hub on your local machine.
253
+ code_revision (`str`, *optional*, defaults to `"main"`):
254
+             The specific revision to use for the code on the Hub, if the code lives in a different repository than
255
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
256
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
257
+ allowed by git.
258
+ kwargs (additional keyword arguments, *optional*):
259
+             Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
260
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
261
+ automatically loaded:
262
+
263
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
264
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
265
+                   already been done).
266
+                 - If a configuration is not provided, `kwargs` will first be passed to the configuration class
267
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
268
+ corresponds to a configuration attribute will be used to override said attribute with the
269
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
270
+ will be passed to the underlying model's `__init__` function.
271
+
272
+ Examples:
273
+
274
+ ```python
275
+ >>> from transformers import AutoConfig, BaseAutoModelClass
276
+
277
+ >>> # Download model and configuration from huggingface.co and cache.
278
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
279
+
280
+ >>> # Update configuration during loading
281
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
282
+ >>> model.config.output_attentions
283
+ True
284
+
285
+ >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
286
+ >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
287
+ >>> model = BaseAutoModelClass.from_pretrained(
288
+ ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
289
+ ... )
290
+ ```
291
+ """
292
+
293
+ FROM_PRETRAINED_FLAX_DOCSTRING = """
294
+ Instantiate one of the model classes of the library from a pretrained model.
295
+
296
+ The model class to instantiate is selected based on the `model_type` property of the config object (either
297
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
298
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
299
+
300
+ List options
301
+
302
+ Args:
303
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
304
+ Can be either:
305
+
306
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
307
+ - A path to a *directory* containing model weights saved using
308
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
309
+             - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In this
310
+               case, `from_pt` should be set to `True` and a configuration object should be provided as the `config`
311
+               argument. This loading path is slower than converting the PyTorch model to a Flax model
312
+               using the provided conversion scripts and loading the Flax model afterwards.
313
+ model_args (additional positional arguments, *optional*):
314
+ Will be passed along to the underlying model `__init__()` method.
315
+ config ([`PretrainedConfig`], *optional*):
316
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
317
+ be automatically loaded when:
318
+
319
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
320
+ model).
321
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
322
+ save directory.
323
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
324
+ configuration JSON file named *config.json* is found in the directory.
325
+ cache_dir (`str` or `os.PathLike`, *optional*):
326
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
327
+ standard cache should not be used.
328
+ from_pt (`bool`, *optional*, defaults to `False`):
329
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
330
+ `pretrained_model_name_or_path` argument).
331
+ force_download (`bool`, *optional*, defaults to `False`):
332
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
333
+ cached versions if they exist.
334
+ resume_download:
335
+ Deprecated and ignored. All downloads are now resumed by default when possible.
336
+ Will be removed in v5 of Transformers.
337
+ proxies (`dict[str, str]`, *optional*):
338
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
339
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
340
+         output_loading_info (`bool`, *optional*, defaults to `False`):
341
+             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
342
+         local_files_only (`bool`, *optional*, defaults to `False`):
343
+             Whether or not to only look at local files (e.g., not try downloading the model).
344
+ revision (`str`, *optional*, defaults to `"main"`):
345
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
346
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
347
+ identifier allowed by git.
348
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
349
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
350
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
351
+ execute code present on the Hub on your local machine.
352
+ code_revision (`str`, *optional*, defaults to `"main"`):
353
+             The specific revision to use for the code on the Hub, if the code lives in a different repository than
354
+ the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
355
+ system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
356
+ allowed by git.
357
+ kwargs (additional keyword arguments, *optional*):
358
+             Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
359
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
360
+ automatically loaded:
361
+
362
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
363
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
364
+                 already been done).
365
+                 - If a configuration is not provided, `kwargs` will first be passed to the configuration class
366
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
367
+ corresponds to a configuration attribute will be used to override said attribute with the
368
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
369
+ will be passed to the underlying model's `__init__` function.
370
+
371
+ Examples:
372
+
373
+ ```python
374
+ >>> from transformers import AutoConfig, BaseAutoModelClass
375
+
376
+ >>> # Download model and configuration from huggingface.co and cache.
377
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
378
+
379
+ >>> # Update configuration during loading
380
+ >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
381
+ >>> model.config.output_attentions
382
+ True
383
+
384
+ >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
385
+ >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
386
+ >>> model = BaseAutoModelClass.from_pretrained(
387
+ ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
388
+ ... )
389
+ ```
390
+ """
391
+
392
+
393
+ def _get_model_class(config, model_mapping):
394
+ supported_models = model_mapping[type(config)]
395
+ if not isinstance(supported_models, (list, tuple)):
396
+ return supported_models
397
+
398
+ name_to_model = {model.__name__: model for model in supported_models}
399
+ architectures = getattr(config, "architectures", [])
400
+ for arch in architectures:
401
+ if arch in name_to_model:
402
+ return name_to_model[arch]
403
+ elif f"TF{arch}" in name_to_model:
404
+ return name_to_model[f"TF{arch}"]
405
+ elif f"Flax{arch}" in name_to_model:
406
+ return name_to_model[f"Flax{arch}"]
407
+
408
+     # If no architecture is set in the config, or none of them matches the supported models, the first element of
409
+     # the tuple is the default.
410
+ return supported_models[0]
411
+
412
+
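A minimal sketch (not the library API) of the resolution order implemented above, using hypothetical stand-in classes: an entry from `config.architectures` is matched against the supported class names as-is first, then with the `TF`/`Flax` prefixes that the TensorFlow and Flax auto classes use.

```python
# Hypothetical stand-ins; not the real transformers classes.
class BertForMaskedLM: ...
class TFBertForMaskedLM: ...

name_to_model = {m.__name__: m for m in (BertForMaskedLM, TFBertForMaskedLM)}
architectures = ["BertForMaskedLM"]  # as stored in config.architectures

for arch in architectures:
    # Try the plain name first, then the framework-prefixed variants.
    for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
        if candidate in name_to_model:
            print(name_to_model[candidate])  # <class '...BertForMaskedLM'>
            break
```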
413
+ class _BaseAutoModelClass:
414
+ # Base class for auto models.
415
+ _model_mapping = None
416
+
417
+ def __init__(self, *args, **kwargs) -> None:
418
+ raise OSError(
419
+ f"{self.__class__.__name__} is designed to be instantiated "
420
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
421
+ f"`{self.__class__.__name__}.from_config(config)` methods."
422
+ )
423
+
424
+ @classmethod
425
+ def from_config(cls, config, **kwargs):
426
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
427
+ has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
428
+ has_local_code = type(config) in cls._model_mapping
429
+ if has_remote_code:
430
+ class_ref = config.auto_map[cls.__name__]
431
+ if "--" in class_ref:
432
+ upstream_repo = class_ref.split("--")[0]
433
+ else:
434
+ upstream_repo = None
435
+ trust_remote_code = resolve_trust_remote_code(
436
+ trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo
437
+ )
438
+
439
+ if has_remote_code and trust_remote_code:
440
+ if "--" in class_ref:
441
+ repo_id, class_ref = class_ref.split("--")
442
+ else:
443
+ repo_id = config.name_or_path
444
+ model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
445
+ # This block handles the case where the user is loading a model with `trust_remote_code=True`
446
+ # but a library model exists with the same name. We don't want to override the autoclass
447
+ # mappings in this case, or all future loads of that model will be the remote code model.
448
+ if not has_local_code:
449
+ cls.register(config.__class__, model_class, exist_ok=True)
450
+ model_class.register_for_auto_class(auto_class=cls)
451
+ _ = kwargs.pop("code_revision", None)
452
+ model_class = add_generation_mixin_to_remote_model(model_class)
453
+ return model_class._from_config(config, **kwargs)
454
+ elif type(config) in cls._model_mapping:
455
+ model_class = _get_model_class(config, cls._model_mapping)
456
+ return model_class._from_config(config, **kwargs)
457
+
458
+ raise ValueError(
459
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
460
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
461
+ )
462
+
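As a hedged usage sketch (assuming PyTorch and the standard `bert` mapping are installed): `from_config` instantiates a freshly initialized model from a config alone, with no weights downloaded, in contrast to `from_pretrained`.

```python
from transformers import AutoConfig, AutoModel

# Build a small config by model type; no weights are fetched at any point.
config = AutoConfig.for_model(
    "bert", hidden_size=64, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=128,
)
model = AutoModel.from_config(config)  # randomly initialized BertModel
print(type(model).__name__)  # BertModel
```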
463
+ @classmethod
464
+ def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
465
+ """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
466
+ return config
467
+
468
+ @classmethod
469
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
470
+ config = kwargs.pop("config", None)
471
+ trust_remote_code = kwargs.get("trust_remote_code")
472
+ kwargs["_from_auto"] = True
473
+ hub_kwargs_names = [
474
+ "cache_dir",
475
+ "force_download",
476
+ "local_files_only",
477
+ "proxies",
478
+ "resume_download",
479
+ "revision",
480
+ "subfolder",
481
+ "use_auth_token",
482
+ "token",
483
+ ]
484
+ hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
485
+ code_revision = kwargs.pop("code_revision", None)
486
+ commit_hash = kwargs.pop("_commit_hash", None)
487
+ adapter_kwargs = kwargs.pop("adapter_kwargs", None)
488
+
489
+ token = hub_kwargs.pop("token", None)
490
+ use_auth_token = hub_kwargs.pop("use_auth_token", None)
491
+ if use_auth_token is not None:
492
+ warnings.warn(
493
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
494
+ FutureWarning,
495
+ )
496
+ if token is not None:
497
+ raise ValueError(
498
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
499
+ )
500
+ token = use_auth_token
501
+
502
+ if token is not None:
503
+ hub_kwargs["token"] = token
504
+
505
+ if commit_hash is None:
506
+ if not isinstance(config, PretrainedConfig):
507
+ # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
508
+ resolved_config_file = cached_file(
509
+ pretrained_model_name_or_path,
510
+ CONFIG_NAME,
511
+ _raise_exceptions_for_gated_repo=False,
512
+ _raise_exceptions_for_missing_entries=False,
513
+ _raise_exceptions_for_connection_errors=False,
514
+ **hub_kwargs,
515
+ )
516
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
517
+ else:
518
+ commit_hash = getattr(config, "_commit_hash", None)
519
+
520
+ if is_peft_available():
521
+ if adapter_kwargs is None:
522
+ adapter_kwargs = {}
523
+ if token is not None:
524
+ adapter_kwargs["token"] = token
525
+
526
+ maybe_adapter_path = find_adapter_config_file(
527
+ pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
528
+ )
529
+
530
+ if maybe_adapter_path is not None:
531
+ with open(maybe_adapter_path, "r", encoding="utf-8") as f:
532
+ adapter_config = json.load(f)
533
+
534
+ adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
535
+ pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
536
+
537
+ if not isinstance(config, PretrainedConfig):
538
+ kwargs_orig = copy.deepcopy(kwargs)
539
+ # ensure not to pollute the config object with dtype="auto" - since it's
540
+ # meaningless in the context of the config object - torch.dtype values are acceptable
541
+ if kwargs.get("torch_dtype") == "auto":
542
+ _ = kwargs.pop("torch_dtype")
543
+ if kwargs.get("dtype") == "auto":
544
+ _ = kwargs.pop("dtype")
545
+ # to not overwrite the quantization_config if config has a quantization_config
546
+ if kwargs.get("quantization_config") is not None:
547
+ _ = kwargs.pop("quantization_config")
548
+
549
+ config, kwargs = AutoConfig.from_pretrained(
550
+ pretrained_model_name_or_path,
551
+ return_unused_kwargs=True,
552
+ code_revision=code_revision,
553
+ _commit_hash=commit_hash,
554
+ **hub_kwargs,
555
+ **kwargs,
556
+ )
557
+
558
+             # if torch_dtype=auto was passed here, ensure it is passed on
559
+ if kwargs_orig.get("torch_dtype", None) == "auto":
560
+ kwargs["torch_dtype"] = "auto"
561
+ if kwargs_orig.get("dtype", None) == "auto":
562
+ kwargs["dtype"] = "auto"
563
+ if kwargs_orig.get("quantization_config", None) is not None:
564
+ kwargs["quantization_config"] = kwargs_orig["quantization_config"]
565
+
566
+ has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
567
+ has_local_code = type(config) in cls._model_mapping
568
+ upstream_repo = None
569
+ if has_remote_code:
570
+ class_ref = config.auto_map[cls.__name__]
571
+ if "--" in class_ref:
572
+ upstream_repo = class_ref.split("--")[0]
573
+ trust_remote_code = resolve_trust_remote_code(
574
+ trust_remote_code,
575
+ pretrained_model_name_or_path,
576
+ has_local_code,
577
+ has_remote_code,
578
+ upstream_repo=upstream_repo,
579
+ )
580
+ kwargs["trust_remote_code"] = trust_remote_code
581
+
582
+ # Set the adapter kwargs
583
+ kwargs["adapter_kwargs"] = adapter_kwargs
584
+
585
+ if has_remote_code and trust_remote_code:
586
+ model_class = get_class_from_dynamic_module(
587
+ class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
588
+ )
589
+ _ = hub_kwargs.pop("code_revision", None)
590
+ # This block handles the case where the user is loading a model with `trust_remote_code=True`
591
+ # but a library model exists with the same name. We don't want to override the autoclass
592
+ # mappings in this case, or all future loads of that model will be the remote code model.
593
+ if not has_local_code:
594
+ cls.register(config.__class__, model_class, exist_ok=True)
595
+ model_class.register_for_auto_class(auto_class=cls)
596
+ model_class = add_generation_mixin_to_remote_model(model_class)
597
+ return model_class.from_pretrained(
598
+ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
599
+ )
600
+ elif type(config) in cls._model_mapping:
601
+ model_class = _get_model_class(config, cls._model_mapping)
602
+ if model_class.config_class == config.sub_configs.get("text_config", None):
603
+ config = config.get_text_config()
604
+ return model_class.from_pretrained(
605
+ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
606
+ )
607
+ raise ValueError(
608
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
609
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
610
+ )
611
+
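The remote-code branch above can be exercised as below; `"some-org/custom-model"` is a placeholder repo id, not a real checkpoint. Because the class is pulled from the repo's own modeling file and executed locally, this should only be enabled for repositories whose code you have read.

```python
from transformers import AutoModel

# Placeholder repo id: such a repo's config.json would carry an `auto_map`
# entry pointing at a custom modeling file, which is downloaded and executed.
model = AutoModel.from_pretrained("some-org/custom-model", trust_remote_code=True)
```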
612
+ @classmethod
613
+ def register(cls, config_class, model_class, exist_ok=False) -> None:
614
+ """
615
+ Register a new model for this class.
616
+
617
+ Args:
618
+ config_class ([`PretrainedConfig`]):
619
+ The configuration corresponding to the model to register.
620
+ model_class ([`PreTrainedModel`]):
621
+ The model to register.
622
+ """
623
+ if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__:
624
+ raise ValueError(
625
+ "The model class you are passing has a `config_class` attribute that is not consistent with the "
626
+                 f"config class you passed (model has {model_class.config_class} and you passed {config_class}). Fix "
627
+ "one of those so they match!"
628
+ )
629
+ cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
630
+
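A minimal sketch of wiring a custom architecture into the auto classes; `MyConfig` and `MyModel` are hypothetical. Note the consistency check above: `model_class.config_class` must name the same config class being registered.

```python
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyModel(PreTrainedModel):
    config_class = MyConfig  # must match, or register() raises ValueError

AutoConfig.register("my-model", MyConfig)
AutoModel.register(MyConfig, MyModel)
```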
631
+
632
+ class _BaseAutoBackboneClass(_BaseAutoModelClass):
633
+ # Base class for auto backbone models.
634
+ _model_mapping = None
635
+
636
+ @classmethod
637
+ def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
638
+ requires_backends(cls, ["vision", "timm"])
639
+ from ...models.timm_backbone import TimmBackboneConfig
640
+
641
+ config = kwargs.pop("config", TimmBackboneConfig())
642
+
643
+ if kwargs.get("out_features") is not None:
644
+ raise ValueError("Cannot specify `out_features` for timm backbones")
645
+
646
+ if kwargs.get("output_loading_info", False):
647
+ raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
648
+
649
+ num_channels = kwargs.pop("num_channels", config.num_channels)
650
+ features_only = kwargs.pop("features_only", config.features_only)
651
+ use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
652
+ out_indices = kwargs.pop("out_indices", config.out_indices)
653
+ config = TimmBackboneConfig(
654
+ backbone=pretrained_model_name_or_path,
655
+ num_channels=num_channels,
656
+ features_only=features_only,
657
+ use_pretrained_backbone=use_pretrained_backbone,
658
+ out_indices=out_indices,
659
+ )
660
+ return super().from_config(config, **kwargs)
661
+
662
+ @classmethod
663
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
664
+ use_timm_backbone = kwargs.pop("use_timm_backbone", False)
665
+ if use_timm_backbone:
666
+ return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
667
+
668
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
669
+
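A hedged usage sketch of the timm path (assuming the `timm` and `vision` extras are installed; `"resnet18"` is a timm model id): with `use_timm_backbone=True` the checkpoint name is handed to `_load_timm_backbone_from_pretrained` and wrapped in a `TimmBackboneConfig` rather than resolved through the config mapping.

```python
from transformers import AutoBackbone

backbone = AutoBackbone.from_pretrained(
    "resnet18",
    use_timm_backbone=True,
    use_pretrained_backbone=False,  # skip downloading timm weights
)
```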
670
+
671
+ def insert_head_doc(docstring, head_doc: str = ""):
672
+ if len(head_doc) > 0:
673
+ return docstring.replace(
674
+ "one of the model classes of the library ",
675
+ f"one of the model classes of the library (with a {head_doc} head) ",
676
+ )
677
+ return docstring.replace(
678
+ "one of the model classes of the library ", "one of the base model classes of the library "
679
+ )
680
+
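The substitution above is a plain string replace; for instance, with `head_doc="language modeling"`:

```python
docstring = "Instantiate one of the model classes of the library from a pretrained model."
print(docstring.replace(
    "one of the model classes of the library ",
    "one of the model classes of the library (with a language modeling head) ",
))
# Instantiate one of the model classes of the library (with a language modeling head) from a pretrained model.
```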
681
+
682
+ def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
683
+ # Create a new class with the right name from the base class
684
+ model_mapping = cls._model_mapping
685
+ name = cls.__name__
686
+ class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
687
+ cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
688
+
689
+ # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
690
+ # have a specific docstrings for them.
691
+ from_config = copy_func(_BaseAutoModelClass.from_config)
692
+ from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
693
+ from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
694
+ from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
695
+ from_config.__doc__ = from_config_docstring
696
+ from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
697
+ cls.from_config = classmethod(from_config)
698
+
699
+ if name.startswith("TF"):
700
+ from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
701
+ elif name.startswith("Flax"):
702
+ from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
703
+ else:
704
+ from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
705
+ from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
706
+ from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
707
+ from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
708
+ from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
709
+ shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
710
+ from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
711
+ from_pretrained.__doc__ = from_pretrained_docstring
712
+ from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
713
+ cls.from_pretrained = classmethod(from_pretrained)
714
+ return cls
715
+
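The `shortcut` used for the example file paths is derived from the example checkpoint name; for the default checkpoint it comes out as follows:

```python
checkpoint_for_example = "google-bert/bert-base-cased"
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
print(shortcut)  # "bert" -> fills the "shortcut_placeholder" slots in the docstrings
```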
716
+
717
+ def get_values(model_mapping):
718
+ result = []
719
+ for model in model_mapping.values():
720
+ if isinstance(model, (list, tuple)):
721
+ result += list(model)
722
+ else:
723
+ result.append(model)
724
+
725
+ return result
726
+
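A tiny illustration of the flattening behavior, with integers standing in for model classes: tuple entries (several classes mapped to one config) are spliced into the result, scalar entries are appended.

```python
mapping = {"a": (1, 2), "b": 3}  # stand-ins for config -> model class entries
result = []
for model in mapping.values():
    result += list(model) if isinstance(model, (list, tuple)) else [model]
print(result)  # [1, 2, 3]
```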
727
+
728
+ def getattribute_from_module(module, attr):
729
+ if attr is None:
730
+ return None
731
+ if isinstance(attr, tuple):
732
+ return tuple(getattribute_from_module(module, a) for a in attr)
733
+ if hasattr(module, attr):
734
+ return getattr(module, attr)
735
+ # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
736
+ # object at the top level.
737
+ transformers_module = importlib.import_module("transformers")
738
+
739
+ if module != transformers_module:
740
+ try:
741
+ return getattribute_from_module(transformers_module, attr)
742
+ except ValueError:
743
+ raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!")
744
+ else:
745
+ raise ValueError(f"Could not find {attr} in {transformers_module}!")
746
+
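A self-contained sketch of the same lookup-with-fallback pattern, using `math` in place of a model submodule and `builtins` in place of the top-level `transformers` module:

```python
import builtins
import math

def lookup(module, attr):
    if hasattr(module, attr):
        return getattr(module, attr)
    if module is not builtins:  # fall back to the parent namespace once
        return lookup(builtins, attr)
    raise ValueError(f"Could not find {attr}!")

print(lookup(math, "pi"))   # resolved on math itself
print(lookup(math, "len"))  # resolved via the builtins fallback
```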
747
+
748
+ def add_generation_mixin_to_remote_model(model_class):
749
+ """
750
+ Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.
751
+
752
+ This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
753
+ `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
754
+ from the Hub may not have the `generate` method after we remove the inheritance.
755
+ """
756
+ # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
757
+ if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
758
+ return model_class
759
+
760
+ # 2. If it already **directly** inherits from GenerationMixin, do nothing
761
+ if "GenerationMixin" in str(model_class.__bases__):
762
+ return model_class
763
+
764
+ # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
765
+ # `prepare_inputs_for_generation` method.
766
+ has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
767
+ getattr(model_class, "generate")
768
+ )
769
+ has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
770
+ getattr(model_class, "prepare_inputs_for_generation")
771
+ )
772
+ if has_custom_generate_in_class or has_custom_prepare_inputs:
773
+ model_class_with_generation_mixin = type(
774
+ model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
775
+ )
776
+ return model_class_with_generation_mixin
777
+ return model_class
778
+
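A minimal sketch of the re-basing trick used above (`Legacy` and `Mixin` are illustrative stand-ins): `type(name, bases, namespace)` builds a new class whose MRO includes the mixin while preserving the original class's own attributes. `Legacy` inherits from an intermediate base here so that copying its `__dict__` does not drag along the `__dict__`/`__weakref__` descriptors.

```python
class Mixin:
    def generate(self):
        return "from mixin"

class Base:  # intermediate base keeps __dict__/__weakref__ out of Legacy.__dict__
    pass

class Legacy(Base):
    def prepare_inputs_for_generation(self):  # the "generate-compatible" marker
        return {}

Patched = type(Legacy.__name__, (Legacy, Mixin), {**Legacy.__dict__})
print(Patched().generate())  # "from mixin"
```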
779
+
780
+ class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
781
+ """
782
+     A mapping from config classes to objects (model or tokenizer classes, for instance) that lazily loads its keys and values when accessed.
783
+
784
+     Args:
785
+         - config_mapping: The mapping from model type to config class
786
+         - model_mapping: The mapping from model type to model (or tokenizer) class
787
+ """
788
+
789
+ def __init__(self, config_mapping, model_mapping) -> None:
790
+ self._config_mapping = config_mapping
791
+ self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
792
+ self._model_mapping = model_mapping
793
+ self._model_mapping._model_mapping = self
794
+ self._extra_content = {}
795
+ self._modules = {}
796
+
797
+ def __len__(self) -> int:
798
+ common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
799
+ return len(common_keys) + len(self._extra_content)
800
+
801
+ def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
802
+ if key in self._extra_content:
803
+ return self._extra_content[key]
804
+ model_type = self._reverse_config_mapping[key.__name__]
805
+ if model_type in self._model_mapping:
806
+ model_name = self._model_mapping[model_type]
807
+ return self._load_attr_from_module(model_type, model_name)
808
+
809
+         # Maybe there were several model types associated with this config.
810
+ model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
811
+ for mtype in model_types:
812
+ if mtype in self._model_mapping:
813
+ model_name = self._model_mapping[mtype]
814
+ return self._load_attr_from_module(mtype, model_name)
815
+ raise KeyError(key)
816
+
817
+ def _load_attr_from_module(self, model_type, attr):
818
+ module_name = model_type_to_module_name(model_type)
819
+ if module_name not in self._modules:
820
+ self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
821
+ return getattribute_from_module(self._modules[module_name], attr)
822
+
823
+ def keys(self) -> list[type[PretrainedConfig]]:
824
+ mapping_keys = [
825
+ self._load_attr_from_module(key, name)
826
+ for key, name in self._config_mapping.items()
827
+ if key in self._model_mapping
828
+ ]
829
+ return mapping_keys + list(self._extra_content.keys())
830
+
831
+ def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
832
+ try:
833
+ return self.__getitem__(key)
834
+ except KeyError:
835
+ return default
836
+
837
+ def __bool__(self) -> bool:
838
+ return bool(self.keys())
839
+
840
+ def values(self) -> list[_LazyAutoMappingValue]:
841
+ mapping_values = [
842
+ self._load_attr_from_module(key, name)
843
+ for key, name in self._model_mapping.items()
844
+ if key in self._config_mapping
845
+ ]
846
+ return mapping_values + list(self._extra_content.values())
847
+
848
+ def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
849
+ mapping_items = [
850
+ (
851
+ self._load_attr_from_module(key, self._config_mapping[key]),
852
+ self._load_attr_from_module(key, self._model_mapping[key]),
853
+ )
854
+ for key in self._model_mapping
855
+ if key in self._config_mapping
856
+ ]
857
+ return mapping_items + list(self._extra_content.items())
858
+
859
+ def __iter__(self) -> Iterator[type[PretrainedConfig]]:
860
+ return iter(self.keys())
861
+
862
+ def __contains__(self, item: type) -> bool:
863
+ if item in self._extra_content:
864
+ return True
865
+ if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
866
+ return False
867
+ model_type = self._reverse_config_mapping[item.__name__]
868
+ return model_type in self._model_mapping
869
+
870
+ def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
871
+ """
872
+ Register a new model in this mapping.
873
+ """
874
+ if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
875
+ model_type = self._reverse_config_mapping[key.__name__]
876
+ if model_type in self._model_mapping and not exist_ok:
877
+ raise ValueError(f"'{key}' is already used by a Transformers model.")
878
+
879
+ self._extra_content[key] = value
880
+
881
+
882
+ __all__ = ["get_values"]
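To close out this file, a toy sketch of the laziness `_LazyAutoMapping` implements, using only the standard library: values are stored as (module name, attribute name) pairs and imported on first access, which keeps import time cheap even with hundreds of model types.

```python
import importlib

class LazyMapping:
    def __init__(self, spec):
        self._spec = spec   # key -> (module_name, attr_name)
        self._cache = {}    # resolved values, filled on demand

    def __getitem__(self, key):
        if key not in self._cache:
            module_name, attr_name = self._spec[key]
            module = importlib.import_module(module_name)
            self._cache[key] = getattr(module, attr_name)
        return self._cache[key]

mapping = LazyMapping({"od": ("collections", "OrderedDict")})
print(mapping["od"])  # <class 'collections.OrderedDict'>, imported just now
```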
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py ADDED
@@ -0,0 +1,1404 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Auto Config class."""
16
+
17
+ import importlib
18
+ import os
19
+ import re
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from collections.abc import Callable, Iterator, KeysView, ValuesView
23
+ from typing import Any, TypeVar, Union
24
+
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...utils import CONFIG_NAME, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
34
+
35
+
36
+ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
37
+ [
38
+ # Add configs here
39
+ ("aimv2", "Aimv2Config"),
40
+ ("aimv2_vision_model", "Aimv2VisionConfig"),
41
+ ("albert", "AlbertConfig"),
42
+ ("align", "AlignConfig"),
43
+ ("altclip", "AltCLIPConfig"),
44
+ ("apertus", "ApertusConfig"),
45
+ ("arcee", "ArceeConfig"),
46
+ ("aria", "AriaConfig"),
47
+ ("aria_text", "AriaTextConfig"),
48
+ ("audio-spectrogram-transformer", "ASTConfig"),
49
+ ("autoformer", "AutoformerConfig"),
50
+ ("aya_vision", "AyaVisionConfig"),
51
+ ("bamba", "BambaConfig"),
52
+ ("bark", "BarkConfig"),
53
+ ("bart", "BartConfig"),
54
+ ("beit", "BeitConfig"),
55
+ ("bert", "BertConfig"),
56
+ ("bert-generation", "BertGenerationConfig"),
57
+ ("big_bird", "BigBirdConfig"),
58
+ ("bigbird_pegasus", "BigBirdPegasusConfig"),
59
+ ("biogpt", "BioGptConfig"),
60
+ ("bit", "BitConfig"),
61
+ ("bitnet", "BitNetConfig"),
62
+ ("blenderbot", "BlenderbotConfig"),
63
+ ("blenderbot-small", "BlenderbotSmallConfig"),
64
+ ("blip", "BlipConfig"),
65
+ ("blip-2", "Blip2Config"),
66
+ ("blip_2_qformer", "Blip2QFormerConfig"),
67
+ ("bloom", "BloomConfig"),
68
+ ("blt", "BltConfig"),
69
+ ("bridgetower", "BridgeTowerConfig"),
70
+ ("bros", "BrosConfig"),
71
+ ("camembert", "CamembertConfig"),
72
+ ("canine", "CanineConfig"),
73
+ ("chameleon", "ChameleonConfig"),
74
+ ("chinese_clip", "ChineseCLIPConfig"),
75
+ ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
76
+ ("clap", "ClapConfig"),
77
+ ("clip", "CLIPConfig"),
78
+ ("clip_text_model", "CLIPTextConfig"),
79
+ ("clip_vision_model", "CLIPVisionConfig"),
80
+ ("clipseg", "CLIPSegConfig"),
81
+ ("clvp", "ClvpConfig"),
82
+ ("code_llama", "LlamaConfig"),
83
+ ("codegen", "CodeGenConfig"),
84
+ ("cohere", "CohereConfig"),
85
+ ("cohere2", "Cohere2Config"),
86
+ ("cohere2_vision", "Cohere2VisionConfig"),
87
+ ("colpali", "ColPaliConfig"),
88
+ ("colqwen2", "ColQwen2Config"),
89
+ ("conditional_detr", "ConditionalDetrConfig"),
90
+ ("convbert", "ConvBertConfig"),
91
+ ("convnext", "ConvNextConfig"),
92
+ ("convnextv2", "ConvNextV2Config"),
93
+ ("cpmant", "CpmAntConfig"),
94
+ ("csm", "CsmConfig"),
95
+ ("ctrl", "CTRLConfig"),
96
+ ("cvt", "CvtConfig"),
97
+ ("d_fine", "DFineConfig"),
98
+ ("dab-detr", "DabDetrConfig"),
99
+ ("dac", "DacConfig"),
100
+ ("data2vec-audio", "Data2VecAudioConfig"),
101
+ ("data2vec-text", "Data2VecTextConfig"),
102
+ ("data2vec-vision", "Data2VecVisionConfig"),
103
+ ("dbrx", "DbrxConfig"),
104
+ ("deberta", "DebertaConfig"),
105
+ ("deberta-v2", "DebertaV2Config"),
106
+ ("decision_transformer", "DecisionTransformerConfig"),
107
+ ("deepseek_v2", "DeepseekV2Config"),
108
+ ("deepseek_v3", "DeepseekV3Config"),
109
+ ("deepseek_vl", "DeepseekVLConfig"),
110
+ ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"),
111
+ ("deformable_detr", "DeformableDetrConfig"),
112
+ ("deit", "DeiTConfig"),
113
+ ("depth_anything", "DepthAnythingConfig"),
114
+ ("depth_pro", "DepthProConfig"),
115
+ ("deta", "DetaConfig"),
116
+ ("detr", "DetrConfig"),
117
+ ("dia", "DiaConfig"),
118
+ ("diffllama", "DiffLlamaConfig"),
119
+ ("dinat", "DinatConfig"),
120
+ ("dinov2", "Dinov2Config"),
121
+ ("dinov2_with_registers", "Dinov2WithRegistersConfig"),
122
+ ("dinov3_convnext", "DINOv3ConvNextConfig"),
123
+ ("dinov3_vit", "DINOv3ViTConfig"),
124
+ ("distilbert", "DistilBertConfig"),
125
+ ("doge", "DogeConfig"),
126
+ ("donut-swin", "DonutSwinConfig"),
127
+ ("dots1", "Dots1Config"),
128
+ ("dpr", "DPRConfig"),
129
+ ("dpt", "DPTConfig"),
130
+ ("edgetam", "EdgeTamConfig"),
131
+ ("edgetam_video", "EdgeTamVideoConfig"),
132
+ ("edgetam_vision_model", "EdgeTamVisionConfig"),
133
+ ("efficientformer", "EfficientFormerConfig"),
134
+ ("efficientloftr", "EfficientLoFTRConfig"),
135
+ ("efficientnet", "EfficientNetConfig"),
136
+ ("electra", "ElectraConfig"),
137
+ ("emu3", "Emu3Config"),
138
+ ("encodec", "EncodecConfig"),
139
+ ("encoder-decoder", "EncoderDecoderConfig"),
140
+ ("eomt", "EomtConfig"),
141
+ ("ernie", "ErnieConfig"),
142
+ ("ernie4_5", "Ernie4_5Config"),
143
+ ("ernie4_5_moe", "Ernie4_5_MoeConfig"),
144
+ ("ernie_m", "ErnieMConfig"),
145
+ ("esm", "EsmConfig"),
146
+ ("evolla", "EvollaConfig"),
147
+ ("exaone4", "Exaone4Config"),
148
+ ("falcon", "FalconConfig"),
149
+ ("falcon_h1", "FalconH1Config"),
150
+ ("falcon_mamba", "FalconMambaConfig"),
151
+ ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
152
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
153
+ ("flaubert", "FlaubertConfig"),
154
+ ("flava", "FlavaConfig"),
155
+ ("flex_olmo", "FlexOlmoConfig"),
156
+ ("florence2", "Florence2Config"),
157
+ ("fnet", "FNetConfig"),
158
+ ("focalnet", "FocalNetConfig"),
159
+ ("fsmt", "FSMTConfig"),
160
+ ("funnel", "FunnelConfig"),
161
+ ("fuyu", "FuyuConfig"),
162
+ ("gemma", "GemmaConfig"),
163
+ ("gemma2", "Gemma2Config"),
164
+ ("gemma3", "Gemma3Config"),
165
+ ("gemma3_text", "Gemma3TextConfig"),
166
+ ("gemma3n", "Gemma3nConfig"),
167
+ ("gemma3n_audio", "Gemma3nAudioConfig"),
168
+ ("gemma3n_text", "Gemma3nTextConfig"),
169
+ ("gemma3n_vision", "Gemma3nVisionConfig"),
170
+ ("git", "GitConfig"),
171
+ ("glm", "GlmConfig"),
172
+ ("glm4", "Glm4Config"),
173
+ ("glm4_moe", "Glm4MoeConfig"),
174
+ ("glm4v", "Glm4vConfig"),
175
+ ("glm4v_moe", "Glm4vMoeConfig"),
176
+ ("glm4v_moe_text", "Glm4vMoeTextConfig"),
177
+ ("glm4v_text", "Glm4vTextConfig"),
178
+ ("glpn", "GLPNConfig"),
179
+ ("got_ocr2", "GotOcr2Config"),
180
+ ("gpt-sw3", "GPT2Config"),
181
+ ("gpt2", "GPT2Config"),
182
+ ("gpt_bigcode", "GPTBigCodeConfig"),
183
+ ("gpt_neo", "GPTNeoConfig"),
184
+ ("gpt_neox", "GPTNeoXConfig"),
185
+ ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
186
+ ("gpt_oss", "GptOssConfig"),
187
+ ("gptj", "GPTJConfig"),
188
+ ("gptsan-japanese", "GPTSanJapaneseConfig"),
189
+ ("granite", "GraniteConfig"),
190
+ ("granite_speech", "GraniteSpeechConfig"),
191
+ ("granitemoe", "GraniteMoeConfig"),
192
+ ("granitemoehybrid", "GraniteMoeHybridConfig"),
193
+ ("granitemoeshared", "GraniteMoeSharedConfig"),
194
+ ("granitevision", "LlavaNextConfig"),
195
+ ("graphormer", "GraphormerConfig"),
196
+ ("grounding-dino", "GroundingDinoConfig"),
197
+ ("groupvit", "GroupViTConfig"),
198
+ ("helium", "HeliumConfig"),
199
+ ("hgnet_v2", "HGNetV2Config"),
200
+ ("hiera", "HieraConfig"),
201
+ ("hubert", "HubertConfig"),
202
+ ("hunyuan_v1_dense", "HunYuanDenseV1Config"),
203
+ ("hunyuan_v1_moe", "HunYuanMoEV1Config"),
204
+ ("ibert", "IBertConfig"),
205
+ ("idefics", "IdeficsConfig"),
206
+ ("idefics2", "Idefics2Config"),
207
+ ("idefics3", "Idefics3Config"),
208
+ ("idefics3_vision", "Idefics3VisionConfig"),
209
+ ("ijepa", "IJepaConfig"),
210
+ ("imagegpt", "ImageGPTConfig"),
211
+ ("informer", "InformerConfig"),
212
+ ("instructblip", "InstructBlipConfig"),
213
+ ("instructblipvideo", "InstructBlipVideoConfig"),
214
+ ("internvl", "InternVLConfig"),
215
+ ("internvl_vision", "InternVLVisionConfig"),
216
+ ("jamba", "JambaConfig"),
217
+ ("janus", "JanusConfig"),
218
+ ("jetmoe", "JetMoeConfig"),
219
+ ("jukebox", "JukeboxConfig"),
220
+ ("kosmos-2", "Kosmos2Config"),
221
+ ("kosmos-2.5", "Kosmos2_5Config"),
222
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
223
+ ("layoutlm", "LayoutLMConfig"),
224
+ ("layoutlmv2", "LayoutLMv2Config"),
225
+ ("layoutlmv3", "LayoutLMv3Config"),
226
+ ("led", "LEDConfig"),
227
+ ("levit", "LevitConfig"),
228
+ ("lfm2", "Lfm2Config"),
229
+ ("lfm2_vl", "Lfm2VlConfig"),
230
+ ("lightglue", "LightGlueConfig"),
231
+ ("lilt", "LiltConfig"),
232
+ ("llama", "LlamaConfig"),
233
+ ("llama4", "Llama4Config"),
234
+ ("llama4_text", "Llama4TextConfig"),
235
+ ("llava", "LlavaConfig"),
236
+ ("llava_next", "LlavaNextConfig"),
237
+ ("llava_next_video", "LlavaNextVideoConfig"),
238
+ ("llava_onevision", "LlavaOnevisionConfig"),
239
+ ("longcat_flash", "LongcatFlashConfig"),
240
+ ("longformer", "LongformerConfig"),
241
+ ("longt5", "LongT5Config"),
242
+ ("luke", "LukeConfig"),
243
+ ("lxmert", "LxmertConfig"),
244
+ ("m2m_100", "M2M100Config"),
245
+ ("mamba", "MambaConfig"),
246
+ ("mamba2", "Mamba2Config"),
247
+ ("marian", "MarianConfig"),
248
+ ("markuplm", "MarkupLMConfig"),
249
+ ("mask2former", "Mask2FormerConfig"),
250
+ ("maskformer", "MaskFormerConfig"),
251
+ ("maskformer-swin", "MaskFormerSwinConfig"),
252
+ ("mbart", "MBartConfig"),
253
+ ("mctct", "MCTCTConfig"),
254
+ ("mega", "MegaConfig"),
255
+ ("megatron-bert", "MegatronBertConfig"),
256
+ ("metaclip_2", "MetaClip2Config"),
257
+ ("mgp-str", "MgpstrConfig"),
258
+ ("mimi", "MimiConfig"),
259
+ ("minimax", "MiniMaxConfig"),
260
+ ("ministral", "MinistralConfig"),
261
+ ("mistral", "MistralConfig"),
262
+ ("mistral3", "Mistral3Config"),
263
+ ("mixtral", "MixtralConfig"),
264
+ ("mlcd", "MLCDVisionConfig"),
265
+ ("mllama", "MllamaConfig"),
266
+ ("mm-grounding-dino", "MMGroundingDinoConfig"),
267
+ ("mobilebert", "MobileBertConfig"),
268
+ ("mobilenet_v1", "MobileNetV1Config"),
269
+ ("mobilenet_v2", "MobileNetV2Config"),
270
+ ("mobilevit", "MobileViTConfig"),
271
+ ("mobilevitv2", "MobileViTV2Config"),
272
+ ("modernbert", "ModernBertConfig"),
273
+ ("modernbert-decoder", "ModernBertDecoderConfig"),
274
+ ("moonshine", "MoonshineConfig"),
275
+ ("moshi", "MoshiConfig"),
276
+ ("mpnet", "MPNetConfig"),
277
+ ("mpt", "MptConfig"),
278
+ ("mra", "MraConfig"),
279
+ ("mt5", "MT5Config"),
280
+ ("musicgen", "MusicgenConfig"),
281
+ ("musicgen_melody", "MusicgenMelodyConfig"),
282
+ ("mvp", "MvpConfig"),
283
+ ("nat", "NatConfig"),
284
+ ("nemotron", "NemotronConfig"),
285
+ ("nezha", "NezhaConfig"),
286
+ ("nllb-moe", "NllbMoeConfig"),
287
+ ("nougat", "VisionEncoderDecoderConfig"),
288
+ ("nystromformer", "NystromformerConfig"),
289
+ ("olmo", "OlmoConfig"),
290
+ ("olmo2", "Olmo2Config"),
291
+ ("olmo3", "Olmo3Config"),
292
+ ("olmoe", "OlmoeConfig"),
293
+ ("omdet-turbo", "OmDetTurboConfig"),
294
+ ("oneformer", "OneFormerConfig"),
295
+ ("open-llama", "OpenLlamaConfig"),
296
+ ("openai-gpt", "OpenAIGPTConfig"),
297
+ ("opt", "OPTConfig"),
298
+ ("ovis2", "Ovis2Config"),
299
+ ("owlv2", "Owlv2Config"),
300
+ ("owlvit", "OwlViTConfig"),
301
+ ("paligemma", "PaliGemmaConfig"),
302
+ ("parakeet_ctc", "ParakeetCTCConfig"),
303
+ ("parakeet_encoder", "ParakeetEncoderConfig"),
304
+ ("patchtsmixer", "PatchTSMixerConfig"),
305
+ ("patchtst", "PatchTSTConfig"),
306
+ ("pegasus", "PegasusConfig"),
307
+ ("pegasus_x", "PegasusXConfig"),
308
+ ("perceiver", "PerceiverConfig"),
309
+ ("perception_encoder", "TimmWrapperConfig"),
310
+ ("perception_lm", "PerceptionLMConfig"),
311
+ ("persimmon", "PersimmonConfig"),
312
+ ("phi", "PhiConfig"),
313
+ ("phi3", "Phi3Config"),
314
+ ("phi4_multimodal", "Phi4MultimodalConfig"),
315
+ ("phimoe", "PhimoeConfig"),
316
+ ("pix2struct", "Pix2StructConfig"),
317
+ ("pixtral", "PixtralVisionConfig"),
318
+ ("plbart", "PLBartConfig"),
319
+ ("poolformer", "PoolFormerConfig"),
320
+ ("pop2piano", "Pop2PianoConfig"),
321
+ ("prompt_depth_anything", "PromptDepthAnythingConfig"),
322
+ ("prophetnet", "ProphetNetConfig"),
323
+ ("pvt", "PvtConfig"),
324
+ ("pvt_v2", "PvtV2Config"),
325
+ ("qdqbert", "QDQBertConfig"),
326
+ ("qwen2", "Qwen2Config"),
327
+ ("qwen2_5_omni", "Qwen2_5OmniConfig"),
328
+ ("qwen2_5_vl", "Qwen2_5_VLConfig"),
329
+ ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
330
+ ("qwen2_audio", "Qwen2AudioConfig"),
331
+ ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
332
+ ("qwen2_moe", "Qwen2MoeConfig"),
333
+ ("qwen2_vl", "Qwen2VLConfig"),
334
+ ("qwen2_vl_text", "Qwen2VLTextConfig"),
335
+ ("qwen3", "Qwen3Config"),
336
+ ("qwen3_moe", "Qwen3MoeConfig"),
337
+ ("qwen3_next", "Qwen3NextConfig"),
338
+ ("qwen3_omni_moe", "Qwen3OmniMoeConfig"),
339
+ ("qwen3_vl", "Qwen3VLConfig"),
340
+ ("qwen3_vl_moe", "Qwen3VLMoeConfig"),
341
+ ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
342
+ ("qwen3_vl_text", "Qwen3VLTextConfig"),
343
+ ("rag", "RagConfig"),
344
+ ("realm", "RealmConfig"),
345
+ ("recurrent_gemma", "RecurrentGemmaConfig"),
346
+ ("reformer", "ReformerConfig"),
347
+ ("regnet", "RegNetConfig"),
348
+ ("rembert", "RemBertConfig"),
349
+ ("resnet", "ResNetConfig"),
350
+ ("retribert", "RetriBertConfig"),
351
+ ("roberta", "RobertaConfig"),
352
+ ("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
353
+ ("roc_bert", "RoCBertConfig"),
354
+ ("roformer", "RoFormerConfig"),
355
+ ("rt_detr", "RTDetrConfig"),
356
+ ("rt_detr_resnet", "RTDetrResNetConfig"),
357
+ ("rt_detr_v2", "RTDetrV2Config"),
358
+ ("rwkv", "RwkvConfig"),
359
+ ("sam", "SamConfig"),
360
+ ("sam2", "Sam2Config"),
361
+ ("sam2_hiera_det_model", "Sam2HieraDetConfig"),
362
+ ("sam2_video", "Sam2VideoConfig"),
363
+ ("sam2_vision_model", "Sam2VisionConfig"),
364
+ ("sam_hq", "SamHQConfig"),
365
+ ("sam_hq_vision_model", "SamHQVisionConfig"),
366
+ ("sam_vision_model", "SamVisionConfig"),
367
+ ("seamless_m4t", "SeamlessM4TConfig"),
368
+ ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
369
+ ("seed_oss", "SeedOssConfig"),
370
+ ("segformer", "SegformerConfig"),
371
+ ("seggpt", "SegGptConfig"),
372
+ ("sew", "SEWConfig"),
373
+ ("sew-d", "SEWDConfig"),
374
+ ("shieldgemma2", "ShieldGemma2Config"),
375
+ ("siglip", "SiglipConfig"),
376
+ ("siglip2", "Siglip2Config"),
377
+ ("siglip2_vision_model", "Siglip2VisionConfig"),
378
+ ("siglip_vision_model", "SiglipVisionConfig"),
379
+ ("smollm3", "SmolLM3Config"),
380
+ ("smolvlm", "SmolVLMConfig"),
381
+ ("smolvlm_vision", "SmolVLMVisionConfig"),
382
+ ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
383
+ ("speech_to_text", "Speech2TextConfig"),
384
+ ("speech_to_text_2", "Speech2Text2Config"),
385
+ ("speecht5", "SpeechT5Config"),
386
+ ("splinter", "SplinterConfig"),
387
+ ("squeezebert", "SqueezeBertConfig"),
388
+ ("stablelm", "StableLmConfig"),
389
+ ("starcoder2", "Starcoder2Config"),
390
+ ("superglue", "SuperGlueConfig"),
391
+ ("superpoint", "SuperPointConfig"),
392
+ ("swiftformer", "SwiftFormerConfig"),
393
+ ("swin", "SwinConfig"),
394
+ ("swin2sr", "Swin2SRConfig"),
395
+ ("swinv2", "Swinv2Config"),
396
+ ("switch_transformers", "SwitchTransformersConfig"),
397
+ ("t5", "T5Config"),
398
+ ("t5gemma", "T5GemmaConfig"),
399
+ ("table-transformer", "TableTransformerConfig"),
400
+ ("tapas", "TapasConfig"),
401
+ ("textnet", "TextNetConfig"),
402
+ ("time_series_transformer", "TimeSeriesTransformerConfig"),
403
+ ("timesfm", "TimesFmConfig"),
404
+ ("timesformer", "TimesformerConfig"),
405
+ ("timm_backbone", "TimmBackboneConfig"),
406
+ ("timm_wrapper", "TimmWrapperConfig"),
407
+ ("trajectory_transformer", "TrajectoryTransformerConfig"),
408
+ ("transfo-xl", "TransfoXLConfig"),
409
+ ("trocr", "TrOCRConfig"),
410
+ ("tvlt", "TvltConfig"),
411
+ ("tvp", "TvpConfig"),
412
+ ("udop", "UdopConfig"),
413
+ ("umt5", "UMT5Config"),
414
+ ("unispeech", "UniSpeechConfig"),
415
+ ("unispeech-sat", "UniSpeechSatConfig"),
416
+ ("univnet", "UnivNetConfig"),
417
+ ("upernet", "UperNetConfig"),
418
+ ("van", "VanConfig"),
419
+ ("vaultgemma", "VaultGemmaConfig"),
420
+ ("video_llava", "VideoLlavaConfig"),
421
+ ("videomae", "VideoMAEConfig"),
422
+ ("vilt", "ViltConfig"),
423
+ ("vipllava", "VipLlavaConfig"),
424
+ ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
425
+ ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
426
+ ("visual_bert", "VisualBertConfig"),
427
+ ("vit", "ViTConfig"),
428
+ ("vit_hybrid", "ViTHybridConfig"),
429
+ ("vit_mae", "ViTMAEConfig"),
430
+ ("vit_msn", "ViTMSNConfig"),
431
+ ("vitdet", "VitDetConfig"),
432
+ ("vitmatte", "VitMatteConfig"),
433
+ ("vitpose", "VitPoseConfig"),
434
+ ("vitpose_backbone", "VitPoseBackboneConfig"),
435
+ ("vits", "VitsConfig"),
436
+ ("vivit", "VivitConfig"),
437
+ ("vjepa2", "VJEPA2Config"),
438
+ ("voxtral", "VoxtralConfig"),
439
+ ("voxtral_encoder", "VoxtralEncoderConfig"),
440
+ ("wav2vec2", "Wav2Vec2Config"),
441
+ ("wav2vec2-bert", "Wav2Vec2BertConfig"),
442
+ ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
443
+ ("wavlm", "WavLMConfig"),
444
+ ("whisper", "WhisperConfig"),
445
+ ("xclip", "XCLIPConfig"),
446
+ ("xcodec", "XcodecConfig"),
447
+ ("xglm", "XGLMConfig"),
448
+ ("xlm", "XLMConfig"),
449
+ ("xlm-prophetnet", "XLMProphetNetConfig"),
450
+ ("xlm-roberta", "XLMRobertaConfig"),
451
+ ("xlm-roberta-xl", "XLMRobertaXLConfig"),
452
+ ("xlnet", "XLNetConfig"),
453
+ ("xlstm", "xLSTMConfig"),
454
+ ("xmod", "XmodConfig"),
455
+ ("yolos", "YolosConfig"),
456
+ ("yoso", "YosoConfig"),
457
+ ("zamba", "ZambaConfig"),
458
+ ("zamba2", "Zamba2Config"),
459
+ ("zoedepth", "ZoeDepthConfig"),
460
+ ]
461
+ )
462
+
463
+
464
+ MODEL_NAMES_MAPPING = OrderedDict[str, str](
465
+ [
466
+ # Add full (and cased) model names here
467
+ ("aimv2", "AIMv2"),
468
+ ("aimv2_vision_model", "Aimv2VisionModel"),
469
+ ("albert", "ALBERT"),
470
+ ("align", "ALIGN"),
471
+ ("altclip", "AltCLIP"),
472
+ ("apertus", "Apertus"),
473
+ ("arcee", "Arcee"),
474
+ ("aria", "Aria"),
475
+ ("aria_text", "AriaText"),
476
+ ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
477
+ ("autoformer", "Autoformer"),
478
+ ("aya_vision", "AyaVision"),
479
+ ("bamba", "Bamba"),
480
+ ("bark", "Bark"),
481
+ ("bart", "BART"),
482
+ ("barthez", "BARThez"),
483
+ ("bartpho", "BARTpho"),
484
+ ("beit", "BEiT"),
485
+ ("bert", "BERT"),
486
+ ("bert-generation", "Bert Generation"),
487
+ ("bert-japanese", "BertJapanese"),
488
+ ("bertweet", "BERTweet"),
489
+ ("big_bird", "BigBird"),
490
+ ("bigbird_pegasus", "BigBird-Pegasus"),
491
+ ("biogpt", "BioGpt"),
492
+ ("bit", "BiT"),
493
+ ("bitnet", "BitNet"),
494
+ ("blenderbot", "Blenderbot"),
495
+ ("blenderbot-small", "BlenderbotSmall"),
496
+ ("blip", "BLIP"),
497
+ ("blip-2", "BLIP-2"),
498
+ ("blip_2_qformer", "BLIP-2 QFormer"),
499
+ ("bloom", "BLOOM"),
500
+ ("blt", "Blt"),
501
+ ("bort", "BORT"),
502
+ ("bridgetower", "BridgeTower"),
503
+ ("bros", "BROS"),
504
+ ("byt5", "ByT5"),
505
+ ("camembert", "CamemBERT"),
506
+ ("canine", "CANINE"),
507
+ ("chameleon", "Chameleon"),
508
+ ("chinese_clip", "Chinese-CLIP"),
509
+ ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
510
+ ("clap", "CLAP"),
511
+ ("clip", "CLIP"),
512
+ ("clip_text_model", "CLIPTextModel"),
513
+ ("clip_vision_model", "CLIPVisionModel"),
514
+ ("clipseg", "CLIPSeg"),
515
+ ("clvp", "CLVP"),
516
+ ("code_llama", "CodeLlama"),
517
+ ("codegen", "CodeGen"),
518
+ ("cohere", "Cohere"),
519
+ ("cohere2", "Cohere2"),
520
+ ("cohere2_vision", "Cohere2Vision"),
521
+ ("colpali", "ColPali"),
522
+ ("colqwen2", "ColQwen2"),
523
+ ("conditional_detr", "Conditional DETR"),
524
+ ("convbert", "ConvBERT"),
525
+ ("convnext", "ConvNeXT"),
526
+ ("convnextv2", "ConvNeXTV2"),
527
+ ("cpm", "CPM"),
528
+ ("cpmant", "CPM-Ant"),
529
+ ("csm", "CSM"),
530
+ ("ctrl", "CTRL"),
531
+ ("cvt", "CvT"),
532
+ ("d_fine", "D-FINE"),
533
+ ("dab-detr", "DAB-DETR"),
534
+ ("dac", "DAC"),
535
+ ("data2vec-audio", "Data2VecAudio"),
536
+ ("data2vec-text", "Data2VecText"),
537
+ ("data2vec-vision", "Data2VecVision"),
538
+ ("dbrx", "DBRX"),
539
+ ("deberta", "DeBERTa"),
540
+ ("deberta-v2", "DeBERTa-v2"),
541
+ ("decision_transformer", "Decision Transformer"),
542
+ ("deepseek_v2", "DeepSeek-V2"),
543
+ ("deepseek_v3", "DeepSeek-V3"),
544
+ ("deepseek_vl", "DeepseekVL"),
545
+ ("deepseek_vl_hybrid", "DeepseekVLHybrid"),
546
+ ("deformable_detr", "Deformable DETR"),
547
+ ("deit", "DeiT"),
548
+ ("deplot", "DePlot"),
549
+ ("depth_anything", "Depth Anything"),
550
+ ("depth_anything_v2", "Depth Anything V2"),
551
+ ("depth_pro", "DepthPro"),
552
+ ("deta", "DETA"),
553
+ ("detr", "DETR"),
554
+ ("dia", "Dia"),
555
+ ("dialogpt", "DialoGPT"),
556
+ ("diffllama", "DiffLlama"),
557
+ ("dinat", "DiNAT"),
558
+ ("dinov2", "DINOv2"),
559
+ ("dinov2_with_registers", "DINOv2 with Registers"),
560
+ ("dinov3_convnext", "DINOv3 ConvNext"),
561
+ ("dinov3_vit", "DINOv3 ViT"),
562
+ ("distilbert", "DistilBERT"),
563
+ ("dit", "DiT"),
564
+ ("doge", "Doge"),
565
+ ("donut-swin", "DonutSwin"),
566
+ ("dots1", "dots1"),
567
+ ("dpr", "DPR"),
568
+ ("dpt", "DPT"),
569
+ ("edgetam", "EdgeTAM"),
570
+ ("edgetam_video", "EdgeTamVideo"),
571
+ ("edgetam_vision_model", "EdgeTamVisionModel"),
572
+ ("efficientformer", "EfficientFormer"),
573
+ ("efficientloftr", "EfficientLoFTR"),
574
+ ("efficientnet", "EfficientNet"),
575
+ ("electra", "ELECTRA"),
576
+ ("emu3", "Emu3"),
577
+ ("encodec", "EnCodec"),
578
+ ("encoder-decoder", "Encoder decoder"),
579
+ ("eomt", "EoMT"),
580
+ ("ernie", "ERNIE"),
581
+ ("ernie4_5", "Ernie4_5"),
582
+ ("ernie4_5_moe", "Ernie4_5_MoE"),
583
+ ("ernie_m", "ErnieM"),
584
+ ("esm", "ESM"),
585
+ ("evolla", "Evolla"),
586
+ ("exaone4", "EXAONE-4.0"),
587
+ ("falcon", "Falcon"),
588
+ ("falcon3", "Falcon3"),
589
+ ("falcon_h1", "FalconH1"),
590
+ ("falcon_mamba", "FalconMamba"),
591
+ ("fastspeech2_conformer", "FastSpeech2Conformer"),
592
+ ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
593
+ ("flan-t5", "FLAN-T5"),
594
+ ("flan-ul2", "FLAN-UL2"),
595
+ ("flaubert", "FlauBERT"),
596
+ ("flava", "FLAVA"),
597
+ ("flex_olmo", "FlexOlmo"),
598
+ ("florence2", "Florence2"),
599
+ ("fnet", "FNet"),
600
+ ("focalnet", "FocalNet"),
601
+ ("fsmt", "FairSeq Machine-Translation"),
602
+ ("funnel", "Funnel Transformer"),
603
+ ("fuyu", "Fuyu"),
604
+ ("gemma", "Gemma"),
605
+ ("gemma2", "Gemma2"),
606
+ ("gemma3", "Gemma3ForConditionalGeneration"),
607
+ ("gemma3_text", "Gemma3ForCausalLM"),
608
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
609
+ ("gemma3n_audio", "Gemma3nAudioEncoder"),
610
+ ("gemma3n_text", "Gemma3nForCausalLM"),
611
+ ("gemma3n_vision", "TimmWrapperModel"),
612
+ ("git", "GIT"),
613
+ ("glm", "GLM"),
614
+ ("glm4", "GLM4"),
615
+ ("glm4_moe", "Glm4MoE"),
616
+ ("glm4v", "GLM4V"),
617
+ ("glm4v_moe", "GLM4VMOE"),
618
+ ("glm4v_moe_text", "GLM4VMOE"),
619
+ ("glm4v_text", "GLM4V"),
620
+ ("glpn", "GLPN"),
621
+ ("got_ocr2", "GOT-OCR2"),
622
+ ("gpt-sw3", "GPT-Sw3"),
623
+ ("gpt2", "OpenAI GPT-2"),
624
+ ("gpt_bigcode", "GPTBigCode"),
625
+ ("gpt_neo", "GPT Neo"),
626
+ ("gpt_neox", "GPT NeoX"),
627
+ ("gpt_neox_japanese", "GPT NeoX Japanese"),
628
+ ("gpt_oss", "GptOss"),
629
+ ("gptj", "GPT-J"),
630
+ ("gptsan-japanese", "GPTSAN-japanese"),
631
+ ("granite", "Granite"),
632
+ ("granite_speech", "GraniteSpeech"),
633
+ ("granitemoe", "GraniteMoeMoe"),
634
+ ("granitemoehybrid", "GraniteMoeHybrid"),
635
+ ("granitemoeshared", "GraniteMoeSharedMoe"),
636
+ ("granitevision", "LLaVA-NeXT"),
637
+ ("graphormer", "Graphormer"),
638
+ ("grounding-dino", "Grounding DINO"),
639
+ ("groupvit", "GroupViT"),
640
+ ("helium", "Helium"),
641
+ ("herbert", "HerBERT"),
642
+ ("hgnet_v2", "HGNet-V2"),
643
+ ("hiera", "Hiera"),
644
+ ("hubert", "Hubert"),
645
+ ("hunyuan_v1_dense", "HunYuanDenseV1"),
646
+ ("hunyuan_v1_moe", "HunYuanMoeV1"),
647
+ ("ibert", "I-BERT"),
648
+ ("idefics", "IDEFICS"),
649
+ ("idefics2", "Idefics2"),
650
+ ("idefics3", "Idefics3"),
651
+ ("idefics3_vision", "Idefics3VisionTransformer"),
652
+ ("ijepa", "I-JEPA"),
653
+ ("imagegpt", "ImageGPT"),
654
+ ("informer", "Informer"),
655
+ ("instructblip", "InstructBLIP"),
656
+ ("instructblipvideo", "InstructBlipVideo"),
657
+ ("internvl", "InternVL"),
658
+ ("internvl_vision", "InternVLVision"),
659
+ ("jamba", "Jamba"),
660
+ ("janus", "Janus"),
661
+ ("jetmoe", "JetMoe"),
662
+ ("jukebox", "Jukebox"),
663
+ ("kosmos-2", "KOSMOS-2"),
664
+ ("kosmos-2.5", "KOSMOS-2.5"),
665
+ ("kyutai_speech_to_text", "KyutaiSpeechToText"),
666
+ ("layoutlm", "LayoutLM"),
667
+ ("layoutlmv2", "LayoutLMv2"),
668
+ ("layoutlmv3", "LayoutLMv3"),
669
+ ("layoutxlm", "LayoutXLM"),
670
+ ("led", "LED"),
671
+ ("levit", "LeViT"),
672
+ ("lfm2", "Lfm2"),
673
+ ("lfm2_vl", "Lfm2Vl"),
674
+ ("lightglue", "LightGlue"),
675
+ ("lilt", "LiLT"),
676
+ ("llama", "LLaMA"),
677
+ ("llama2", "Llama2"),
678
+ ("llama3", "Llama3"),
679
+ ("llama4", "Llama4"),
680
+ ("llama4_text", "Llama4ForCausalLM"),
681
+ ("llava", "LLaVa"),
682
+ ("llava_next", "LLaVA-NeXT"),
683
+ ("llava_next_video", "LLaVa-NeXT-Video"),
684
+ ("llava_onevision", "LLaVA-Onevision"),
685
+ ("longcat_flash", "LongCatFlash"),
686
+ ("longformer", "Longformer"),
687
+ ("longt5", "LongT5"),
688
+ ("luke", "LUKE"),
689
+ ("lxmert", "LXMERT"),
690
+ ("m2m_100", "M2M100"),
691
+ ("madlad-400", "MADLAD-400"),
692
+ ("mamba", "Mamba"),
693
+ ("mamba2", "mamba2"),
694
+ ("marian", "Marian"),
695
+ ("markuplm", "MarkupLM"),
696
+ ("mask2former", "Mask2Former"),
697
+ ("maskformer", "MaskFormer"),
698
+ ("maskformer-swin", "MaskFormerSwin"),
699
+ ("matcha", "MatCha"),
700
+ ("mbart", "mBART"),
701
+ ("mbart50", "mBART-50"),
702
+ ("mctct", "M-CTC-T"),
703
+ ("mega", "MEGA"),
704
+ ("megatron-bert", "Megatron-BERT"),
705
+ ("megatron_gpt2", "Megatron-GPT2"),
706
+ ("metaclip_2", "MetaCLIP 2"),
707
+ ("mgp-str", "MGP-STR"),
708
+ ("mimi", "Mimi"),
709
+ ("minimax", "MiniMax"),
710
+ ("ministral", "Ministral"),
711
+ ("mistral", "Mistral"),
712
+ ("mistral3", "Mistral3"),
713
+ ("mixtral", "Mixtral"),
714
+ ("mlcd", "MLCD"),
715
+ ("mllama", "Mllama"),
716
+ ("mluke", "mLUKE"),
717
+ ("mm-grounding-dino", "MM Grounding DINO"),
718
+ ("mms", "MMS"),
719
+ ("mobilebert", "MobileBERT"),
720
+ ("mobilenet_v1", "MobileNetV1"),
721
+ ("mobilenet_v2", "MobileNetV2"),
722
+ ("mobilevit", "MobileViT"),
723
+ ("mobilevitv2", "MobileViTV2"),
724
+ ("modernbert", "ModernBERT"),
725
+ ("modernbert-decoder", "ModernBertDecoder"),
726
+ ("moonshine", "Moonshine"),
727
+ ("moshi", "Moshi"),
728
+ ("mpnet", "MPNet"),
729
+ ("mpt", "MPT"),
730
+ ("mra", "MRA"),
731
+ ("mt5", "MT5"),
732
+ ("musicgen", "MusicGen"),
733
+ ("musicgen_melody", "MusicGen Melody"),
734
+ ("mvp", "MVP"),
735
+ ("myt5", "myt5"),
736
+ ("nat", "NAT"),
737
+ ("nemotron", "Nemotron"),
738
+ ("nezha", "Nezha"),
739
+ ("nllb", "NLLB"),
740
+ ("nllb-moe", "NLLB-MOE"),
741
+ ("nougat", "Nougat"),
742
+ ("nystromformer", "Nyströmformer"),
743
+ ("olmo", "OLMo"),
744
+ ("olmo2", "OLMo2"),
745
+ ("olmo3", "Olmo3"),
746
+ ("olmoe", "OLMoE"),
747
+ ("omdet-turbo", "OmDet-Turbo"),
748
+ ("oneformer", "OneFormer"),
749
+ ("open-llama", "OpenLlama"),
750
+ ("openai-gpt", "OpenAI GPT"),
751
+ ("opt", "OPT"),
752
+ ("ovis2", "Ovis2"),
753
+ ("owlv2", "OWLv2"),
754
+ ("owlvit", "OWL-ViT"),
755
+ ("paligemma", "PaliGemma"),
756
+ ("parakeet", "Parakeet"),
757
+ ("parakeet_ctc", "Parakeet"),
758
+ ("parakeet_encoder", "ParakeetEncoder"),
759
+ ("patchtsmixer", "PatchTSMixer"),
760
+ ("patchtst", "PatchTST"),
761
+ ("pegasus", "Pegasus"),
762
+ ("pegasus_x", "PEGASUS-X"),
763
+ ("perceiver", "Perceiver"),
764
+ ("perception_encoder", "PerceptionEncoder"),
765
+ ("perception_lm", "PerceptionLM"),
766
+ ("persimmon", "Persimmon"),
767
+ ("phi", "Phi"),
768
+ ("phi3", "Phi3"),
769
+ ("phi4_multimodal", "Phi4Multimodal"),
770
+ ("phimoe", "Phimoe"),
771
+ ("phobert", "PhoBERT"),
772
+ ("pix2struct", "Pix2Struct"),
773
+ ("pixtral", "Pixtral"),
774
+ ("plbart", "PLBart"),
775
+ ("poolformer", "PoolFormer"),
776
+ ("pop2piano", "Pop2Piano"),
777
+ ("prompt_depth_anything", "PromptDepthAnything"),
778
+ ("prophetnet", "ProphetNet"),
779
+ ("pvt", "PVT"),
780
+ ("pvt_v2", "PVTv2"),
781
+ ("qdqbert", "QDQBert"),
782
+ ("qwen2", "Qwen2"),
783
+ ("qwen2_5_omni", "Qwen2_5Omni"),
784
+ ("qwen2_5_vl", "Qwen2_5_VL"),
785
+ ("qwen2_5_vl_text", "Qwen2_5_VL"),
786
+ ("qwen2_audio", "Qwen2Audio"),
787
+ ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
788
+ ("qwen2_moe", "Qwen2MoE"),
789
+ ("qwen2_vl", "Qwen2VL"),
790
+ ("qwen2_vl_text", "Qwen2VL"),
791
+ ("qwen3", "Qwen3"),
792
+ ("qwen3_moe", "Qwen3MoE"),
793
+ ("qwen3_next", "Qwen3Next"),
794
+ ("qwen3_omni_moe", "Qwen3OmniMoE"),
795
+ ("qwen3_vl", "Qwen3VL"),
796
+ ("qwen3_vl_moe", "Qwen3VLMoe"),
797
+ ("qwen3_vl_moe_text", "Qwen3VLMoe"),
798
+ ("qwen3_vl_text", "Qwen3VL"),
799
+ ("rag", "RAG"),
800
+ ("realm", "REALM"),
801
+ ("recurrent_gemma", "RecurrentGemma"),
802
+ ("reformer", "Reformer"),
803
+ ("regnet", "RegNet"),
804
+ ("rembert", "RemBERT"),
805
+ ("resnet", "ResNet"),
806
+ ("retribert", "RetriBERT"),
807
+ ("roberta", "RoBERTa"),
808
+ ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
809
+ ("roc_bert", "RoCBert"),
810
+ ("roformer", "RoFormer"),
811
+ ("rt_detr", "RT-DETR"),
812
+ ("rt_detr_resnet", "RT-DETR-ResNet"),
813
+ ("rt_detr_v2", "RT-DETRv2"),
814
+ ("rwkv", "RWKV"),
815
+ ("sam", "SAM"),
816
+ ("sam2", "SAM2"),
817
+ ("sam2_hiera_det_model", "Sam2HieraDetModel"),
818
+ ("sam2_video", "Sam2VideoModel"),
819
+ ("sam2_vision_model", "Sam2VisionModel"),
820
+ ("sam_hq", "SAM-HQ"),
821
+ ("sam_hq_vision_model", "SamHQVisionModel"),
822
+ ("sam_vision_model", "SamVisionModel"),
823
+ ("seamless_m4t", "SeamlessM4T"),
824
+ ("seamless_m4t_v2", "SeamlessM4Tv2"),
825
+ ("seed_oss", "SeedOss"),
826
+ ("segformer", "SegFormer"),
827
+ ("seggpt", "SegGPT"),
828
+ ("sew", "SEW"),
829
+ ("sew-d", "SEW-D"),
830
+ ("shieldgemma2", "Shieldgemma2"),
831
+ ("siglip", "SigLIP"),
832
+ ("siglip2", "SigLIP2"),
833
+ ("siglip2_vision_model", "Siglip2VisionModel"),
834
+ ("siglip_vision_model", "SiglipVisionModel"),
835
+ ("smollm3", "SmolLM3"),
836
+ ("smolvlm", "SmolVLM"),
837
+ ("smolvlm_vision", "SmolVLMVisionTransformer"),
838
+ ("speech-encoder-decoder", "Speech Encoder decoder"),
839
+ ("speech_to_text", "Speech2Text"),
840
+ ("speech_to_text_2", "Speech2Text2"),
841
+ ("speecht5", "SpeechT5"),
842
+ ("splinter", "Splinter"),
843
+ ("squeezebert", "SqueezeBERT"),
844
+ ("stablelm", "StableLm"),
845
+ ("starcoder2", "Starcoder2"),
846
+ ("superglue", "SuperGlue"),
847
+ ("superpoint", "SuperPoint"),
848
+ ("swiftformer", "SwiftFormer"),
849
+ ("swin", "Swin Transformer"),
850
+ ("swin2sr", "Swin2SR"),
851
+ ("swinv2", "Swin Transformer V2"),
852
+ ("switch_transformers", "SwitchTransformers"),
853
+ ("t5", "T5"),
854
+ ("t5gemma", "T5Gemma"),
855
+ ("t5v1.1", "T5v1.1"),
856
+ ("table-transformer", "Table Transformer"),
857
+ ("tapas", "TAPAS"),
858
+ ("tapex", "TAPEX"),
859
+ ("textnet", "TextNet"),
860
+ ("time_series_transformer", "Time Series Transformer"),
861
+ ("timesfm", "TimesFm"),
862
+ ("timesformer", "TimeSformer"),
863
+ ("timm_backbone", "TimmBackbone"),
864
+ ("timm_wrapper", "TimmWrapperModel"),
865
+ ("trajectory_transformer", "Trajectory Transformer"),
866
+ ("transfo-xl", "Transformer-XL"),
867
+ ("trocr", "TrOCR"),
868
+ ("tvlt", "TVLT"),
869
+ ("tvp", "TVP"),
870
+ ("udop", "UDOP"),
871
+ ("ul2", "UL2"),
872
+ ("umt5", "UMT5"),
873
+ ("unispeech", "UniSpeech"),
874
+ ("unispeech-sat", "UniSpeechSat"),
875
+ ("univnet", "UnivNet"),
876
+ ("upernet", "UPerNet"),
877
+ ("van", "VAN"),
878
+ ("vaultgemma", "VaultGemma"),
879
+ ("video_llava", "VideoLlava"),
880
+ ("videomae", "VideoMAE"),
881
+ ("vilt", "ViLT"),
882
+ ("vipllava", "VipLlava"),
883
+ ("vision-encoder-decoder", "Vision Encoder decoder"),
884
+ ("vision-text-dual-encoder", "VisionTextDualEncoder"),
885
+ ("visual_bert", "VisualBERT"),
886
+ ("vit", "ViT"),
887
+ ("vit_hybrid", "ViT Hybrid"),
888
+ ("vit_mae", "ViTMAE"),
889
+ ("vit_msn", "ViTMSN"),
890
+ ("vitdet", "VitDet"),
891
+ ("vitmatte", "ViTMatte"),
892
+ ("vitpose", "ViTPose"),
893
+ ("vitpose_backbone", "ViTPoseBackbone"),
894
+ ("vits", "VITS"),
895
+ ("vivit", "ViViT"),
896
+ ("vjepa2", "VJEPA2Model"),
897
+ ("voxtral", "Voxtral"),
898
+ ("voxtral_encoder", "Voxtral Encoder"),
899
+ ("wav2vec2", "Wav2Vec2"),
900
+ ("wav2vec2-bert", "Wav2Vec2-BERT"),
901
+ ("wav2vec2-conformer", "Wav2Vec2-Conformer"),
902
+ ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
903
+ ("wavlm", "WavLM"),
904
+ ("whisper", "Whisper"),
905
+ ("xclip", "X-CLIP"),
906
+ ("xcodec", "X-CODEC"),
907
+ ("xglm", "XGLM"),
908
+ ("xlm", "XLM"),
909
+ ("xlm-prophetnet", "XLM-ProphetNet"),
910
+ ("xlm-roberta", "XLM-RoBERTa"),
911
+ ("xlm-roberta-xl", "XLM-RoBERTa-XL"),
912
+ ("xlm-v", "XLM-V"),
913
+ ("xlnet", "XLNet"),
914
+ ("xls_r", "XLS-R"),
915
+ ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
916
+ ("xlstm", "xLSTM"),
917
+ ("xmod", "X-MOD"),
918
+ ("yolos", "YOLOS"),
919
+ ("yoso", "YOSO"),
920
+ ("zamba", "Zamba"),
921
+ ("zamba2", "Zamba2"),
922
+ ("zoedepth", "ZoeDepth"),
923
+ ]
924
+ )
925
+
926
+ # This is tied to the processing `-` -> `_` in `model_type_to_module_name`. For example, instead of putting
927
+ # `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`.
928
+ DEPRECATED_MODELS = [
929
+ "bort",
930
+ "deta",
931
+ "efficientformer",
932
+ "ernie_m",
933
+ "gptsan_japanese",
934
+ "graphormer",
935
+ "jukebox",
936
+ "mctct",
937
+ "mega",
938
+ "mmbt",
939
+ "nat",
940
+ "nezha",
941
+ "open_llama",
942
+ "qdqbert",
943
+ "realm",
944
+ "retribert",
945
+ "speech_to_text_2",
946
+ "tapex",
947
+ "trajectory_transformer",
948
+ "transfo_xl",
949
+ "tvlt",
950
+ "van",
951
+ "vit_hybrid",
952
+ "xlm_prophetnet",
953
+ ]
954
+
955
+ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
956
+ [
957
+ ("openai-gpt", "openai"),
958
+ ("data2vec-audio", "data2vec"),
959
+ ("data2vec-text", "data2vec"),
960
+ ("data2vec-vision", "data2vec"),
961
+ ("donut-swin", "donut"),
962
+ ("kosmos-2", "kosmos2"),
963
+ ("kosmos-2.5", "kosmos2_5"),
964
+ ("maskformer-swin", "maskformer"),
965
+ ("xclip", "x_clip"),
966
+ ("clip_vision_model", "clip"),
967
+ ("qwen2_audio_encoder", "qwen2_audio"),
968
+ ("voxtral_encoder", "voxtral"),
969
+ ("clip_text_model", "clip"),
970
+ ("aria_text", "aria"),
971
+ ("gemma3_text", "gemma3"),
972
+ ("gemma3n_audio", "gemma3n"),
973
+ ("gemma3n_text", "gemma3n"),
974
+ ("gemma3n_vision", "gemma3n"),
975
+ ("glm4v_text", "glm4v"),
976
+ ("glm4v_moe_text", "glm4v_moe"),
977
+ ("idefics3_vision", "idefics3"),
978
+ ("siglip_vision_model", "siglip"),
979
+ ("siglip2_vision_model", "siglip2"),
980
+ ("aimv2_vision_model", "aimv2"),
981
+ ("smolvlm_vision", "smolvlm"),
982
+ ("chinese_clip_vision_model", "chinese_clip"),
983
+ ("rt_detr_resnet", "rt_detr"),
984
+ ("granitevision", "llava_next"),
985
+ ("internvl_vision", "internvl"),
986
+ ("qwen2_5_vl_text", "qwen2_5_vl"),
987
+ ("qwen2_vl_text", "qwen2_vl"),
988
+ ("qwen3_vl_text", "qwen3_vl"),
989
+ ("qwen3_vl_moe_text", "qwen3_vl_moe"),
990
+ ("sam_vision_model", "sam"),
991
+ ("sam2_vision_model", "sam2"),
992
+ ("edgetam_vision_model", "edgetam"),
993
+ ("sam2_hiera_det_model", "sam2"),
994
+ ("sam_hq_vision_model", "sam_hq"),
995
+ ("llama4_text", "llama4"),
996
+ ("blip_2_qformer", "blip_2"),
997
+ ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
998
+ ("perception_encoder", "perception_lm"),
999
+ ("parakeet_encoder", "parakeet"),
1000
+ ("parakeet_ctc", "parakeet"),
1001
+ ]
1002
+ )
1003
+
1004
+
1005
+ def model_type_to_module_name(key) -> str:
1006
+ """Converts a config key to the corresponding module."""
1007
+ # Special treatment
1008
+ if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
1009
+ key = SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
1010
+
1011
+ if key in DEPRECATED_MODELS:
1012
+ key = f"deprecated.{key}"
1013
+ return key
1014
+
1015
+ key = key.replace("-", "_")
1016
+ if key in DEPRECATED_MODELS:
1017
+ key = f"deprecated.{key}"
1018
+
1019
+ return key
1020
+
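A quick sanity check of the resolution order above (special-case table first, then dash-to-underscore, then the `deprecated.` prefix); a minimal sketch assuming a standard transformers install:

```python
# Sketch: how model_type_to_module_name resolves module names. Inputs are
# entries from the mappings defined above.
from transformers.models.auto.configuration_auto import model_type_to_module_name

print(model_type_to_module_name("data2vec-text"))  # "data2vec" (special-case table)
print(model_type_to_module_name("transfo-xl"))     # "deprecated.transfo_xl" (dash -> underscore, then deprecated prefix)
print(model_type_to_module_name("llama"))          # "llama" (pass-through)
```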
1021
+
1022
+ def config_class_to_model_type(config) -> Union[str, None]:
1023
+ """Converts a config class name to the corresponding model type"""
1024
+ for key, cls in CONFIG_MAPPING_NAMES.items():
1025
+ if cls == config:
1026
+ return key
1027
+ # if key not found check in extra content
1028
+ for key, cls in CONFIG_MAPPING._extra_content.items():
1029
+ if cls.__name__ == config:
1030
+ return key
1031
+ return None
1032
+
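The reverse lookup above matches on class *names* (strings), so it can be exercised without instantiating anything; the sketch below assumes the standard `"bert" -> "BertConfig"` entry:

```python
# Sketch of the reverse lookup; assumes the standard "bert" -> "BertConfig" entry.
from transformers.models.auto.configuration_auto import config_class_to_model_type

print(config_class_to_model_type("BertConfig"))    # "bert"
print(config_class_to_model_type("NoSuchConfig"))  # None
```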
1033
+
1034
+ class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
1035
+ """
1036
+ A dictionary that lazily loads its values when they are requested.
1037
+ """
1038
+
1039
+ def __init__(self, mapping) -> None:
1040
+ self._mapping = mapping
1041
+ self._extra_content = {}
1042
+ self._modules = {}
1043
+
1044
+ def __getitem__(self, key: str) -> type[PretrainedConfig]:
1045
+ if key in self._extra_content:
1046
+ return self._extra_content[key]
1047
+ if key not in self._mapping:
1048
+ raise KeyError(key)
1049
+ value = self._mapping[key]
1050
+ module_name = model_type_to_module_name(key)
1051
+ if module_name not in self._modules:
1052
+ self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
1053
+ if hasattr(self._modules[module_name], value):
1054
+ return getattr(self._modules[module_name], value)
1055
+
1056
+ # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
1057
+ # object at the top level.
1058
+ transformers_module = importlib.import_module("transformers")
1059
+ return getattr(transformers_module, value)
1060
+
1061
+ def keys(self) -> list[str]:
1062
+ return list(self._mapping.keys()) + list(self._extra_content.keys())
1063
+
1064
+ def values(self) -> list[type[PretrainedConfig]]:
1065
+ return [self[k] for k in self._mapping] + list(self._extra_content.values())
1066
+
1067
+ def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
1068
+ return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())
1069
+
1070
+ def __iter__(self) -> Iterator[str]:
1071
+ return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
1072
+
1073
+ def __contains__(self, item: object) -> bool:
1074
+ return item in self._mapping or item in self._extra_content
1075
+
1076
+ def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
1077
+ """
1078
+ Register a new configuration in this mapping.
1079
+ """
1080
+ if key in self._mapping and not exist_ok:
1081
+ raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
1082
+ self._extra_content[key] = value
1083
+
1084
+
1085
+ CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
1086
+
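`CONFIG_MAPPING` resolves entries on demand: the model module is imported only when its key is first accessed. A minimal sketch, assuming a standard install:

```python
# First access to a key triggers the lazy import of transformers.models.bert.
from transformers.models.auto.configuration_auto import CONFIG_MAPPING

config_cls = CONFIG_MAPPING["bert"]
print(config_cls.__name__)  # "BertConfig"
```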
1087
+
1088
+ class _LazyLoadAllMappings(OrderedDict[str, str]):
1089
+ """
1090
+ A mapping that will load all key/value pairs at the first access (either by indexing, requesting keys, values,
1091
+ etc.)
1092
+
1093
+ Args:
1094
+ mapping: The mapping to load.
1095
+ """
1096
+
1097
+ def __init__(self, mapping):
1098
+ self._mapping = mapping
1099
+ self._initialized = False
1100
+ self._data = {}
1101
+
1102
+ def _initialize(self):
1103
+ if self._initialized:
1104
+ return
1105
+
1106
+ for model_type, map_name in self._mapping.items():
1107
+ module_name = model_type_to_module_name(model_type)
1108
+ module = importlib.import_module(f".{module_name}", "transformers.models")
1109
+ mapping = getattr(module, map_name)
1110
+ self._data.update(mapping)
1111
+
1112
+ self._initialized = True
1113
+
1114
+ def __getitem__(self, key):
1115
+ self._initialize()
1116
+ return self._data[key]
1117
+
1118
+ def keys(self) -> KeysView[str]:
1119
+ self._initialize()
1120
+ return self._data.keys()
1121
+
1122
+ def values(self) -> ValuesView[str]:
1123
+ self._initialize()
1124
+ return self._data.values()
1125
+
1126
+ def items(self):
1127
+ self._initialize()
1128
+ return self._data.items()
1129
+
1130
+ def __iter__(self) -> Iterator[str]:
1131
+ self._initialize()
1132
+ return iter(self._data)
1133
+
1134
+ def __contains__(self, item: object) -> bool:
1135
+ self._initialize()
1136
+ return item in self._data
1137
+
1138
+
1139
+ def _get_class_name(model_class: Union[str, list[str]]):
1140
+ if isinstance(model_class, (list, tuple)):
1141
+ return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
1142
+ return f"[`{model_class}`]"
1143
+
1144
+
1145
+ def _list_model_options(indent, config_to_class=None, use_model_types=True):
1146
+ if config_to_class is None and not use_model_types:
1147
+ raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
1148
+ if use_model_types:
1149
+ if config_to_class is None:
1150
+ model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
1151
+ else:
1152
+ model_type_to_name = {
1153
+ model_type: _get_class_name(model_class)
1154
+ for model_type, model_class in config_to_class.items()
1155
+ if model_type in MODEL_NAMES_MAPPING
1156
+ }
1157
+ lines = [
1158
+ f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
1159
+ for model_type in sorted(model_type_to_name.keys())
1160
+ ]
1161
+ else:
1162
+ config_to_name = {
1163
+ CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
1164
+ for config, clas in config_to_class.items()
1165
+ if config in CONFIG_MAPPING_NAMES
1166
+ }
1167
+ config_to_model_name = {
1168
+ config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
1169
+ }
1170
+ lines = [
1171
+ f"{indent}- [`{config_name}`] configuration class:"
1172
+ f" {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
1173
+ for config_name in sorted(config_to_name.keys())
1174
+ ]
1175
+ return "\n".join(lines)
1176
+
1177
+
1178
+ def replace_list_option_in_docstrings(
1179
+ config_to_class=None, use_model_types: bool = True
1180
+ ) -> Callable[[_CallableT], _CallableT]:
1181
+ def docstring_decorator(fn):
1182
+ docstrings = fn.__doc__
1183
+ if docstrings is None:
1184
+ # Happens e.g. when Python runs with -OO, which strips docstrings.
1185
+ return fn
1186
+ lines = docstrings.split("\n")
1187
+ i = 0
1188
+ while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
1189
+ i += 1
1190
+ if i < len(lines):
1191
+ indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
1192
+ if use_model_types:
1193
+ indent = f"{indent} "
1194
+ lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
1195
+ docstrings = "\n".join(lines)
1196
+ else:
1197
+ raise ValueError(
1198
+ f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
1199
+ f" docstring is:\n{docstrings}"
1200
+ )
1201
+ fn.__doc__ = docstrings
1202
+ return fn
1203
+
1204
+ return docstring_decorator
1205
+
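To make the decorator's contract concrete, here is a hypothetical function it could wrap; the bare `List options` line is the placeholder that gets expanded (a function without it raises `ValueError`):

```python
# Hypothetical example; `describe` is illustrative, not part of the library.
from transformers.models.auto.configuration_auto import replace_list_option_in_docstrings

@replace_list_option_in_docstrings()
def describe(model_type):
    """Describe a model.

    List options
    """

# The "List options" line is now one bullet per model type, e.g.
# "    - **albert** -- [`AlbertConfig`] (ALBERT model)".
print("**albert**" in describe.__doc__)  # True
```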
1206
+
1207
+ class AutoConfig:
1208
+ r"""
1209
+ This is a generic configuration class that will be instantiated as one of the configuration classes of the library
1210
+ when created with the [`~AutoConfig.from_pretrained`] class method.
1211
+
1212
+ This class cannot be instantiated directly using `__init__()` (throws an error).
1213
+ """
1214
+
1215
+ def __init__(self) -> None:
1216
+ raise OSError(
1217
+ "AutoConfig is designed to be instantiated "
1218
+ "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
1219
+ )
1220
+
1221
+ @classmethod
1222
+ def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
1223
+ if model_type in CONFIG_MAPPING:
1224
+ config_class = CONFIG_MAPPING[model_type]
1225
+ return config_class(*args, **kwargs)
1226
+ raise ValueError(
1227
+ f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
1228
+ )
1229
+
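`for_model` is the direct path when the model type is already known; positional and keyword arguments are forwarded to the config constructor:

```python
# Standard usage of AutoConfig.for_model (no download involved).
from transformers import AutoConfig

config = AutoConfig.for_model("bert", hidden_size=256)
print(type(config).__name__, config.hidden_size)  # BertConfig 256
```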
1230
+ @classmethod
1231
+ @replace_list_option_in_docstrings()
1232
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
1233
+ r"""
1234
+ Instantiate one of the configuration classes of the library from a pretrained model configuration.
1235
+
1236
+ The configuration class to instantiate is selected based on the `model_type` property of the config object that
1237
+ is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
1238
+
1239
+ List options
1240
+
1241
+ Args:
1242
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
1243
+ Can be either:
1244
+
1245
+ - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
1246
+ huggingface.co.
1247
+ - A path to a *directory* containing a configuration file saved using the
1248
+ [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
1249
+ e.g., `./my_model_directory/`.
1250
+ - A path or url to a saved configuration JSON *file*, e.g.,
1251
+ `./my_model_directory/configuration.json`.
1252
+ cache_dir (`str` or `os.PathLike`, *optional*):
1253
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
1254
+ standard cache should not be used.
1255
+ force_download (`bool`, *optional*, defaults to `False`):
1256
+ Whether or not to force the (re-)download of the model weights and configuration files and override the
1257
+ cached versions if they exist.
1258
+ resume_download:
1259
+ Deprecated and ignored. All downloads are now resumed by default when possible.
1260
+ Will be removed in v5 of Transformers.
1261
+ proxies (`dict[str, str]`, *optional*):
1262
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1263
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1264
+ revision (`str`, *optional*, defaults to `"main"`):
1265
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1266
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1267
+ identifier allowed by git.
1268
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
1269
+ If `False`, then this function returns just the final configuration object.
1270
+
1271
+ If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
1272
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
1273
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
1274
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
1275
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
1276
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
1277
+ execute code present on the Hub on your local machine.
1278
+ kwargs (additional keyword arguments, *optional*):
1279
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
1280
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
1281
+ by the `return_unused_kwargs` keyword parameter.
1282
+
1283
+ Examples:
1284
+
1285
+ ```python
1286
+ >>> from transformers import AutoConfig
1287
+
1288
+ >>> # Download configuration from huggingface.co and cache.
1289
+ >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
1290
+
1291
+ >>> # Download configuration from huggingface.co (user-uploaded) and cache.
1292
+ >>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
1293
+
1294
+ >>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
1295
+ >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
1296
+
1297
+ >>> # Load a specific configuration file.
1298
+ >>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
1299
+
1300
+ >>> # Change some config attributes when loading a pretrained config.
1301
+ >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
1302
+ >>> config.output_attentions
1303
+ True
1304
+
1305
+ >>> config, unused_kwargs = AutoConfig.from_pretrained(
1306
+ ... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
1307
+ ... )
1308
+ >>> config.output_attentions
1309
+ True
1310
+
1311
+ >>> unused_kwargs
1312
+ {'foo': False}
1313
+ ```
1314
+ """
1315
+ use_auth_token = kwargs.pop("use_auth_token", None)
1316
+ if use_auth_token is not None:
1317
+ warnings.warn(
1318
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1319
+ FutureWarning,
1320
+ )
1321
+ if kwargs.get("token") is not None:
1322
+ raise ValueError(
1323
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1324
+ )
1325
+ kwargs["token"] = use_auth_token
1326
+
1327
+ kwargs["_from_auto"] = True
1328
+ kwargs["name_or_path"] = pretrained_model_name_or_path
1329
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
1330
+ code_revision = kwargs.pop("code_revision", None)
1331
+
1332
+ config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
1333
+ has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
1334
+ has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
1335
+ if has_remote_code:
1336
+ class_ref = config_dict["auto_map"]["AutoConfig"]
1337
+ if "--" in class_ref:
1338
+ upstream_repo = class_ref.split("--")[0]
1339
+ else:
1340
+ upstream_repo = None
1341
+ trust_remote_code = resolve_trust_remote_code(
1342
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
1343
+ )
1344
+
1345
+ if has_remote_code and trust_remote_code:
1346
+ config_class = get_class_from_dynamic_module(
1347
+ class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
1348
+ )
1349
+ config_class.register_for_auto_class()
1350
+ return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
1351
+ elif "model_type" in config_dict:
1352
+ # Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral
1353
+ if config_dict["model_type"] == "mistral" and "layer_types" in config_dict:
1354
+ logger.info(
1355
+ "Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. "
1356
+ )
1357
+ config_dict["model_type"] = "ministral"
1358
+
1359
+ try:
1360
+ config_class = CONFIG_MAPPING[config_dict["model_type"]]
1361
+ except KeyError:
1362
+ raise ValueError(
1363
+ f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
1364
+ "but Transformers does not recognize this architecture. This could be because of an "
1365
+ "issue with the checkpoint, or because your version of Transformers is out of date.\n\n"
1366
+ "You can update Transformers with the command `pip install --upgrade transformers`. If this "
1367
+ "does not work, and the checkpoint is very new, then there may not be a release version "
1368
+ "that supports this model yet. In this case, you can get the most up-to-date code by installing "
1369
+ "Transformers from source with the command "
1370
+ "`pip install git+https://github.com/huggingface/transformers.git`"
1371
+ )
1372
+ return config_class.from_dict(config_dict, **unused_kwargs)
1373
+ else:
1374
+ # Fallback: use pattern matching on the string.
1375
+ # We go from longer names to shorter names to catch roberta before bert (for instance)
1376
+ for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
1377
+ if pattern in str(pretrained_model_name_or_path):
1378
+ return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
1379
+
1380
+ raise ValueError(
1381
+ f"Unrecognized model in {pretrained_model_name_or_path}. "
1382
+ f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
1383
+ f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
1384
+ )
1385
+
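One branch of the method above worth calling out: when a repo ships its own `AutoConfig` entry under `auto_map`, the remote class is fetched and executed only once `trust_remote_code` resolves to `True`. A sketch with an illustrative (not real) repo id:

```python
# The repo id below is hypothetical. trust_remote_code executes the repo's own
# Python, so only enable it for repositories you have reviewed.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("some-org/custom-model", trust_remote_code=True)
```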
1386
+ @staticmethod
1387
+ def register(model_type, config, exist_ok=False) -> None:
1388
+ """
1389
+ Register a new configuration for this class.
1390
+
1391
+ Args:
1392
+ model_type (`str`): The model type like "bert" or "gpt".
1393
+ config ([`PretrainedConfig`]): The config to register.
1394
+ """
1395
+ if issubclass(config, PretrainedConfig) and config.model_type != model_type:
1396
+ raise ValueError(
1397
+ "The config you are passing has a `model_type` attribute that is not consistent with the model type "
1398
+ f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
1399
+ "match!"
1400
+ )
1401
+ CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
1402
+
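Registration wires a new model type into `CONFIG_MAPPING` so the auto classes can resolve it; a minimal sketch with hypothetical names:

```python
# MyConfig and "my-model" are hypothetical names for illustration.
from transformers import AutoConfig, PretrainedConfig

class MyConfig(PretrainedConfig):
    model_type = "my-model"

AutoConfig.register("my-model", MyConfig)
print(AutoConfig.for_model("my-model").model_type)  # "my-model"
```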
1403
+
1404
+ __all__ = ["CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py ADDED
@@ -0,0 +1,422 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AutoFeatureExtractor class."""
16
+
17
+ import importlib
18
+ import json
19
+ import os
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from typing import Optional, Union
23
+
24
+ # Build the list of all feature extractors
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...feature_extraction_utils import FeatureExtractionMixin
28
+ from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
29
+ from .auto_factory import _LazyAutoMapping
30
+ from .configuration_auto import (
31
+ CONFIG_MAPPING_NAMES,
32
+ AutoConfig,
33
+ model_type_to_module_name,
34
+ replace_list_option_in_docstrings,
35
+ )
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
41
+ [
42
+ ("audio-spectrogram-transformer", "ASTFeatureExtractor"),
43
+ ("beit", "BeitFeatureExtractor"),
44
+ ("chinese_clip", "ChineseCLIPFeatureExtractor"),
45
+ ("clap", "ClapFeatureExtractor"),
46
+ ("clip", "CLIPFeatureExtractor"),
47
+ ("clipseg", "ViTFeatureExtractor"),
48
+ ("clvp", "ClvpFeatureExtractor"),
49
+ ("conditional_detr", "ConditionalDetrFeatureExtractor"),
50
+ ("convnext", "ConvNextFeatureExtractor"),
51
+ ("cvt", "ConvNextFeatureExtractor"),
52
+ ("dac", "DacFeatureExtractor"),
53
+ ("data2vec-audio", "Wav2Vec2FeatureExtractor"),
54
+ ("data2vec-vision", "BeitFeatureExtractor"),
55
+ ("deformable_detr", "DeformableDetrFeatureExtractor"),
56
+ ("deit", "DeiTFeatureExtractor"),
57
+ ("detr", "DetrFeatureExtractor"),
58
+ ("dia", "DiaFeatureExtractor"),
59
+ ("dinat", "ViTFeatureExtractor"),
60
+ ("donut-swin", "DonutFeatureExtractor"),
61
+ ("dpt", "DPTFeatureExtractor"),
62
+ ("encodec", "EncodecFeatureExtractor"),
63
+ ("flava", "FlavaFeatureExtractor"),
64
+ ("gemma3n", "Gemma3nAudioFeatureExtractor"),
65
+ ("glpn", "GLPNFeatureExtractor"),
66
+ ("granite_speech", "GraniteSpeechFeatureExtractor"),
67
+ ("groupvit", "CLIPFeatureExtractor"),
68
+ ("hubert", "Wav2Vec2FeatureExtractor"),
69
+ ("imagegpt", "ImageGPTFeatureExtractor"),
70
+ ("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
71
+ ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
72
+ ("layoutlmv3", "LayoutLMv3FeatureExtractor"),
73
+ ("levit", "LevitFeatureExtractor"),
74
+ ("maskformer", "MaskFormerFeatureExtractor"),
75
+ ("mctct", "MCTCTFeatureExtractor"),
76
+ ("mimi", "EncodecFeatureExtractor"),
77
+ ("mobilenet_v1", "MobileNetV1FeatureExtractor"),
78
+ ("mobilenet_v2", "MobileNetV2FeatureExtractor"),
79
+ ("mobilevit", "MobileViTFeatureExtractor"),
80
+ ("moonshine", "Wav2Vec2FeatureExtractor"),
81
+ ("moshi", "EncodecFeatureExtractor"),
82
+ ("nat", "ViTFeatureExtractor"),
83
+ ("owlvit", "OwlViTFeatureExtractor"),
84
+ ("parakeet_ctc", "ParakeetFeatureExtractor"),
85
+ ("parakeet_encoder", "ParakeetFeatureExtractor"),
86
+ ("perceiver", "PerceiverFeatureExtractor"),
87
+ ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
88
+ ("poolformer", "PoolFormerFeatureExtractor"),
89
+ ("pop2piano", "Pop2PianoFeatureExtractor"),
90
+ ("regnet", "ConvNextFeatureExtractor"),
91
+ ("resnet", "ConvNextFeatureExtractor"),
92
+ ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
93
+ ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
94
+ ("segformer", "SegformerFeatureExtractor"),
95
+ ("sew", "Wav2Vec2FeatureExtractor"),
96
+ ("sew-d", "Wav2Vec2FeatureExtractor"),
97
+ ("speech_to_text", "Speech2TextFeatureExtractor"),
98
+ ("speecht5", "SpeechT5FeatureExtractor"),
99
+ ("swiftformer", "ViTFeatureExtractor"),
100
+ ("swin", "ViTFeatureExtractor"),
101
+ ("swinv2", "ViTFeatureExtractor"),
102
+ ("table-transformer", "DetrFeatureExtractor"),
103
+ ("timesformer", "VideoMAEFeatureExtractor"),
104
+ ("tvlt", "TvltFeatureExtractor"),
105
+ ("unispeech", "Wav2Vec2FeatureExtractor"),
106
+ ("unispeech-sat", "Wav2Vec2FeatureExtractor"),
107
+ ("univnet", "UnivNetFeatureExtractor"),
108
+ ("van", "ConvNextFeatureExtractor"),
109
+ ("videomae", "VideoMAEFeatureExtractor"),
110
+ ("vilt", "ViltFeatureExtractor"),
111
+ ("vit", "ViTFeatureExtractor"),
112
+ ("vit_mae", "ViTFeatureExtractor"),
113
+ ("vit_msn", "ViTFeatureExtractor"),
114
+ ("wav2vec2", "Wav2Vec2FeatureExtractor"),
115
+ ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
116
+ ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
117
+ ("wavlm", "Wav2Vec2FeatureExtractor"),
118
+ ("whisper", "WhisperFeatureExtractor"),
119
+ ("xclip", "CLIPFeatureExtractor"),
120
+ ("xcodec", "DacFeatureExtractor"),
121
+ ("yolos", "YolosFeatureExtractor"),
122
+ ]
123
+ )
124
+
125
+ FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
126
+
127
+
128
+ def feature_extractor_class_from_name(class_name: str):
129
+ for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
130
+ if class_name in extractors:
131
+ module_name = model_type_to_module_name(module_name)
132
+
133
+ module = importlib.import_module(f".{module_name}", "transformers.models")
134
+ try:
135
+ return getattr(module, class_name)
136
+ except AttributeError:
137
+ continue
138
+
139
+ for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
140
+ if getattr(extractor, "__name__", None) == class_name:
141
+ return extractor
142
+
143
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
144
+ # init and we return the proper dummy to get an appropriate error message.
145
+ main_module = importlib.import_module("transformers")
146
+ if hasattr(main_module, class_name):
147
+ return getattr(main_module, class_name)
148
+
149
+ return None
150
+
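The resolver above matches by class name against the lazy mapping, then falls back to the top-level package (which exposes dummy objects when a backend dependency is missing):

```python
# Sketch of the name-based lookup, assuming a standard install.
from transformers.models.auto.feature_extraction_auto import feature_extractor_class_from_name

print(feature_extractor_class_from_name("Wav2Vec2FeatureExtractor"))  # the class (or a dummy if deps are missing)
print(feature_extractor_class_from_name("NoSuchExtractor"))           # None
```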
151
+
152
+ def get_feature_extractor_config(
153
+ pretrained_model_name_or_path: Union[str, os.PathLike],
154
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
155
+ force_download: bool = False,
156
+ resume_download: Optional[bool] = None,
157
+ proxies: Optional[dict[str, str]] = None,
158
+ token: Optional[Union[bool, str]] = None,
159
+ revision: Optional[str] = None,
160
+ local_files_only: bool = False,
161
+ **kwargs,
162
+ ):
163
+ """
164
+ Loads the feature extractor configuration of a pretrained model.
165
+
166
+ Args:
167
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
168
+ This can be either:
169
+
170
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
171
+ huggingface.co.
172
+ - a path to a *directory* containing a configuration file saved using the
173
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., `./my_model_directory/`.
174
+
175
+ cache_dir (`str` or `os.PathLike`, *optional*):
176
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
177
+ cache should not be used.
178
+ force_download (`bool`, *optional*, defaults to `False`):
179
+ Whether or not to force the (re-)download of the configuration files and override the cached versions if they
180
+ exist.
181
+ resume_download:
182
+ Deprecated and ignored. All downloads are now resumed by default when possible.
183
+ Will be removed in v5 of Transformers.
184
+ proxies (`dict[str, str]`, *optional*):
185
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
186
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
187
+ token (`str` or *bool*, *optional*):
188
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
189
+ when running `hf auth login` (stored in `~/.huggingface`).
190
+ revision (`str`, *optional*, defaults to `"main"`):
191
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
192
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
193
+ identifier allowed by git.
194
+ local_files_only (`bool`, *optional*, defaults to `False`):
195
+ If `True`, will only try to load the feature extractor configuration from local files.
196
+
197
+ <Tip>
198
+
199
+ Passing `token=True` is required when you want to use a private model.
200
+
201
+ </Tip>
202
+
203
+ Returns:
204
+ `Dict`: The configuration of the feature extractor.
205
+
206
+ Examples:
207
+
208
+ ```python
209
+ # Download configuration from huggingface.co and cache.
210
+ feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
211
+ # This model does not have a feature extractor config so the result will be an empty dict.
212
+ feature_extractor_config = get_feature_extractor_config("google-bert/bert-base-uncased")
213
+
214
+ # Save a pretrained feature extractor locally and you can reload its config
215
+ from transformers import AutoFeatureExtractor
216
+
217
+ feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
218
+ feature_extractor.save_pretrained("feature-extractor-test")
219
+ feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
220
+ ```"""
221
+ use_auth_token = kwargs.pop("use_auth_token", None)
222
+ if use_auth_token is not None:
223
+ warnings.warn(
224
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
225
+ FutureWarning,
226
+ )
227
+ if token is not None:
228
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
229
+ token = use_auth_token
230
+
231
+ resolved_config_file = cached_file(
232
+ pretrained_model_name_or_path,
233
+ FEATURE_EXTRACTOR_NAME,
234
+ cache_dir=cache_dir,
235
+ force_download=force_download,
236
+ resume_download=resume_download,
237
+ proxies=proxies,
238
+ token=token,
239
+ revision=revision,
240
+ local_files_only=local_files_only,
241
+ _raise_exceptions_for_gated_repo=False,
242
+ _raise_exceptions_for_missing_entries=False,
243
+ _raise_exceptions_for_connection_errors=False,
244
+ )
245
+ if resolved_config_file is None:
246
+ logger.info(
247
+ "Could not locate the feature extractor configuration file, will try to use the model config instead."
248
+ )
249
+ return {}
250
+
251
+ with open(resolved_config_file, encoding="utf-8") as reader:
252
+ return json.load(reader)
253
+
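`get_feature_extractor_config` is the lightweight alternative when only the raw JSON is needed, without instantiating a feature extractor:

```python
# Fetch preprocessor_config.json as a plain dict (network access assumed).
from transformers.models.auto.feature_extraction_auto import get_feature_extractor_config

fe_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
print(fe_config.get("feature_extractor_type"))  # "Wav2Vec2FeatureExtractor"
```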
254
+
255
+ class AutoFeatureExtractor:
256
+ r"""
257
+ This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
258
+ library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.
259
+
260
+ This class cannot be instantiated directly using `__init__()` (throws an error).
261
+ """
262
+
263
+ def __init__(self):
264
+ raise OSError(
265
+ "AutoFeatureExtractor is designed to be instantiated "
266
+ "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
267
+ )
268
+
269
+ @classmethod
270
+ @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
271
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
272
+ r"""
273
+ Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
274
+
275
+ The feature extractor class to instantiate is selected based on the `model_type` property of the config object
276
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
277
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
278
+
279
+ List options
280
+
281
+ Params:
282
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
283
+ This can be either:
284
+
285
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
286
+ huggingface.co.
287
+ - a path to a *directory* containing a feature extractor file saved using the
288
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
289
+ `./my_model_directory/`.
290
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
291
+ `./my_model_directory/preprocessor_config.json`.
292
+ cache_dir (`str` or `os.PathLike`, *optional*):
293
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
294
+ standard cache should not be used.
295
+ force_download (`bool`, *optional*, defaults to `False`):
296
+ Whether or not to force the (re-)download of the feature extractor files and override the cached versions
297
+ if they exist.
298
+ resume_download:
299
+ Deprecated and ignored. All downloads are now resumed by default when possible.
300
+ Will be removed in v5 of Transformers.
301
+ proxies (`dict[str, str]`, *optional*):
302
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
303
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
304
+ token (`str` or *bool*, *optional*):
305
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
306
+ when running `hf auth login` (stored in `~/.huggingface`).
307
+ revision (`str`, *optional*, defaults to `"main"`):
308
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
309
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
310
+ identifier allowed by git.
311
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
312
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
313
+ function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
314
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
315
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
316
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
317
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
318
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
319
+ execute code present on the Hub on your local machine.
320
+ kwargs (`dict[str, Any]`, *optional*):
321
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
322
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
323
+ controlled by the `return_unused_kwargs` keyword parameter.
324
+
325
+ <Tip>
326
+
327
+ Passing `token=True` is required when you want to use a private model.
328
+
329
+ </Tip>
330
+
331
+ Examples:
332
+
333
+ ```python
334
+ >>> from transformers import AutoFeatureExtractor
335
+
336
+ >>> # Download feature extractor from huggingface.co and cache.
337
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
338
+
339
+ >>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
340
+ >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
341
+ ```"""
342
+ use_auth_token = kwargs.pop("use_auth_token", None)
343
+ if use_auth_token is not None:
344
+ warnings.warn(
345
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
346
+ FutureWarning,
347
+ )
348
+ if kwargs.get("token") is not None:
349
+ raise ValueError(
350
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
351
+ )
352
+ kwargs["token"] = use_auth_token
353
+
354
+ config = kwargs.pop("config", None)
355
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
356
+ kwargs["_from_auto"] = True
357
+
358
+ config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
359
+ feature_extractor_class = config_dict.get("feature_extractor_type", None)
360
+ feature_extractor_auto_map = None
361
+ if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
362
+ feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
363
+
364
+ # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
365
+ if feature_extractor_class is None and feature_extractor_auto_map is None:
366
+ if not isinstance(config, PretrainedConfig):
367
+ config = AutoConfig.from_pretrained(
368
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
369
+ )
370
+ # It could be in `config.feature_extractor_type``
371
+ feature_extractor_class = getattr(config, "feature_extractor_type", None)
372
+ if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
373
+ feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
374
+
375
+ if feature_extractor_class is not None:
376
+ feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
377
+
378
+ has_remote_code = feature_extractor_auto_map is not None
379
+ has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
380
+ if has_remote_code:
381
+ if "--" in feature_extractor_auto_map:
382
+ upstream_repo = feature_extractor_auto_map.split("--")[0]
383
+ else:
384
+ upstream_repo = None
385
+ trust_remote_code = resolve_trust_remote_code(
386
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
387
+ )
388
+
389
+ if has_remote_code and trust_remote_code:
390
+ feature_extractor_class = get_class_from_dynamic_module(
391
+ feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
392
+ )
393
+ _ = kwargs.pop("code_revision", None)
394
+ feature_extractor_class.register_for_auto_class()
395
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
396
+ elif feature_extractor_class is not None:
397
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
398
+ # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
399
+ elif type(config) in FEATURE_EXTRACTOR_MAPPING:
400
+ feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
401
+ return feature_extractor_class.from_dict(config_dict, **kwargs)
402
+
403
+ raise ValueError(
404
+ f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
405
+ f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
406
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES)}"
407
+ )
408
+
409
+ @staticmethod
410
+ def register(config_class, feature_extractor_class, exist_ok=False):
411
+ """
412
+ Register a new feature extractor for this class.
413
+
414
+ Args:
415
+ config_class ([`PretrainedConfig`]):
416
+ The configuration corresponding to the model to register.
417
+ feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
418
+ """
419
+ FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)
420
+
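Mirroring `AutoConfig.register`, the hook above lets third-party code plug a custom extractor into the auto machinery; every name below is hypothetical:

```python
# Minimal sketch; MyConfig / MyFeatureExtractor / "my-audio-model" are hypothetical.
from transformers import AutoConfig, AutoFeatureExtractor, PretrainedConfig, SequenceFeatureExtractor

class MyConfig(PretrainedConfig):
    model_type = "my-audio-model"

class MyFeatureExtractor(SequenceFeatureExtractor):
    def __init__(self, feature_size=1, sampling_rate=16000, padding_value=0.0, **kwargs):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

AutoConfig.register("my-audio-model", MyConfig)
AutoFeatureExtractor.register(MyConfig, MyFeatureExtractor)
# AutoFeatureExtractor can now resolve checkpoints whose config is MyConfig.
```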
421
+
422
+ __all__ = ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py ADDED
@@ -0,0 +1,688 @@
+ # coding=utf-8
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """AutoImageProcessor class."""
+
+ import importlib
+ import json
+ import os
+ import warnings
+ from collections import OrderedDict
+ from typing import TYPE_CHECKING, Optional, Union
+
+ # Build the list of all image processors
+ from ...configuration_utils import PretrainedConfig
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+ from ...image_processing_utils import ImageProcessingMixin
+ from ...image_processing_utils_fast import BaseImageProcessorFast
+ from ...utils import (
+     CONFIG_NAME,
+     IMAGE_PROCESSOR_NAME,
+     cached_file,
+     is_timm_config_dict,
+     is_timm_local_checkpoint,
+     is_torchvision_available,
+     is_vision_available,
+     logging,
+ )
+ from ...utils.import_utils import requires
+ from .auto_factory import _LazyAutoMapping
+ from .configuration_auto import (
+     CONFIG_MAPPING_NAMES,
+     AutoConfig,
+     model_type_to_module_name,
+     replace_list_option_in_docstrings,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]
+
+
+ if TYPE_CHECKING:
+     # This significantly improves completion suggestion performance when
+     # the transformers package is used with Microsoft's Pylance language server.
+     IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
+ else:
+     IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
+         [
+             ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
+             ("aria", ("AriaImageProcessor", None)),
+             ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
+             ("bit", ("BitImageProcessor", "BitImageProcessorFast")),
+             ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
+             ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
+             ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")),
+             ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")),
+             ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
+             ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
+             ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
+             ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+             ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+             ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+             ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
+             ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
+             ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
+             ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
+             ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
+             ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
+             ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
+             ("deta", ("DetaImageProcessor", None)),
+             ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
+             ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
+             ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
+             ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
+             ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
+             ("edgetam", (None, "Sam2ImageProcessorFast")),
+             ("efficientformer", ("EfficientFormerImageProcessor", None)),
+             ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
+             ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
+             ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
+             ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
+             ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
+             ("fuyu", ("FuyuImageProcessor", None)),
+             ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
+             ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+             ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
+             ("glpn", ("GLPNImageProcessor", None)),
+             ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
+             ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
+             ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
+             ("idefics", ("IdeficsImageProcessor", None)),
+             ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
+             ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
+             ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
+             ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
+             ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)),
+             ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
+             ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
+             ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
+             ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
+             ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
+             ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
+             ("lightglue", ("LightGlueImageProcessor", None)),
+             ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
+             ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
+             ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
+             ("llava_next_video", ("LlavaNextVideoImageProcessor", None)),
+             ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")),
+             ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")),
+             ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")),
+             ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
+             ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("mllama", ("MllamaImageProcessor", None)),
+             ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
+             ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
+             ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
+             ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
+             ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
+             ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
+             ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
+             ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
+             ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
+             ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
+             ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+             ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
+             ("perception_lm", (None, "PerceptionLMImageProcessorFast")),
+             ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")),
+             ("pix2struct", ("Pix2StructImageProcessor", None)),
+             ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
+             ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
+             ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
+             ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
+             ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
+             ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+             ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+             ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
+             ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+             ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+             ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
+             ("sam", ("SamImageProcessor", "SamImageProcessorFast")),
+             ("sam2", (None, "Sam2ImageProcessorFast")),
+             ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
+             ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
+             ("seggpt", ("SegGptImageProcessor", None)),
+             ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
+             ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
+             ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
+             ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
+             ("superglue", ("SuperGlueImageProcessor", None)),
+             ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
+             ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
+             ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")),
+             ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
+             ("timesformer", ("VideoMAEImageProcessor", None)),
+             ("timm_wrapper", ("TimmWrapperImageProcessor", None)),
+             ("tvlt", ("TvltImageProcessor", None)),
+             ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
+             ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
+             ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
+             ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
+             ("videomae", ("VideoMAEImageProcessor", None)),
+             ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
+             ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("vit_hybrid", ("ViTHybridImageProcessor", None)),
+             ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
+             ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
+             ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
+             ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
+             ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
+         ]
+     )
+
+ # Override to None if the packages are not available
+ for model_type, (slow_class, fast_class) in IMAGE_PROCESSOR_MAPPING_NAMES.items():
+     if not is_vision_available():
+         slow_class = None
+     if not is_torchvision_available():
+         fast_class = None
+
+     IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_class, fast_class)
+
+ IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
+
+
+ def get_image_processor_class_from_name(class_name: str):
+     if class_name == "BaseImageProcessorFast":
+         return BaseImageProcessorFast
+
+     for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
+         if class_name in extractors:
+             module_name = model_type_to_module_name(module_name)
+
+             module = importlib.import_module(f".{module_name}", "transformers.models")
+             try:
+                 return getattr(module, class_name)
+             except AttributeError:
+                 continue
+
+     for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values():
+         for extractor in extractors:
+             if getattr(extractor, "__name__", None) == class_name:
+                 return extractor
+
+     # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+     # init and we return the proper dummy to get an appropriate error message.
+     main_module = importlib.import_module("transformers")
+     if hasattr(main_module, class_name):
+         return getattr(main_module, class_name)
+
+     return None
+
+
+ def get_image_processor_config(
+     pretrained_model_name_or_path: Union[str, os.PathLike],
+     cache_dir: Optional[Union[str, os.PathLike]] = None,
+     force_download: bool = False,
+     resume_download: Optional[bool] = None,
+     proxies: Optional[dict[str, str]] = None,
+     token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     local_files_only: bool = False,
+     **kwargs,
+ ):
+     """
+     Loads the image processor configuration from a pretrained model image processor configuration.
+
+     Args:
+         pretrained_model_name_or_path (`str` or `os.PathLike`):
+             This can be either:
+
+             - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+               huggingface.co.
+             - a path to a *directory* containing a configuration file saved using the
+               [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+         cache_dir (`str` or `os.PathLike`, *optional*):
+             Path to a directory in which a downloaded pretrained model configuration should be cached if the
+             standard cache should not be used.
+         force_download (`bool`, *optional*, defaults to `False`):
+             Whether or not to force to (re-)download the configuration files and override the cached versions if
+             they exist.
+         resume_download:
+             Deprecated and ignored. All downloads are now resumed by default when possible.
+             Will be removed in v5 of Transformers.
+         proxies (`dict[str, str]`, *optional*):
+             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+             'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+         token (`str` or *bool*, *optional*):
+             The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+             when running `hf auth login` (stored in `~/.huggingface`).
+         revision (`str`, *optional*, defaults to `"main"`):
+             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+             git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+             identifier allowed by git.
+         local_files_only (`bool`, *optional*, defaults to `False`):
+             If `True`, will only try to load the image processor configuration from local files.
+
+     <Tip>
+
+     Passing `token=True` is required when you want to use a private model.
+
+     </Tip>
+
+     Returns:
+         `Dict`: The configuration of the image processor.
+
+     Examples:
+
+     ```python
+     # Download configuration from huggingface.co and cache.
+     image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
+     # This model does not have an image processor config so the result will be an empty dict.
+     image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")
+
+     # Save a pretrained image processor locally and you can reload its config
+     from transformers import AutoImageProcessor
+
+     image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+     image_processor.save_pretrained("image-processor-test")
+     image_processor_config = get_image_processor_config("image-processor-test")
+     ```"""
+     use_auth_token = kwargs.pop("use_auth_token", None)
+     if use_auth_token is not None:
+         warnings.warn(
+             "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+             FutureWarning,
+         )
+         if token is not None:
+             raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
+         token = use_auth_token
+
+     resolved_config_file = cached_file(
+         pretrained_model_name_or_path,
+         IMAGE_PROCESSOR_NAME,
+         cache_dir=cache_dir,
+         force_download=force_download,
+         resume_download=resume_download,
+         proxies=proxies,
+         token=token,
+         revision=revision,
+         local_files_only=local_files_only,
+         _raise_exceptions_for_gated_repo=False,
+         _raise_exceptions_for_missing_entries=False,
+         _raise_exceptions_for_connection_errors=False,
+     )
+     if resolved_config_file is None:
+         logger.info(
+             "Could not locate the image processor configuration file, will try to use the model config instead."
+         )
+         return {}
+
+     with open(resolved_config_file, encoding="utf-8") as reader:
+         return json.load(reader)
+
+
+ def _warning_fast_image_processor_available(fast_class):
+     logger.warning(
+         f"Fast image processor class {fast_class} is available for this model. "
+         "Using slow image processor class. To use the fast image processor class set `use_fast=True`."
+     )
+
+
+ @requires(backends=("vision",))
+ class AutoImageProcessor:
+     r"""
+     This is a generic image processor class that will be instantiated as one of the image processor classes of the
+     library when created with the [`AutoImageProcessor.from_pretrained`] class method.
+
+     This class cannot be instantiated directly using `__init__()` (throws an error).
+     """
+
+     def __init__(self):
+         raise OSError(
+             "AutoImageProcessor is designed to be instantiated "
+             "using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
+         )
+
+     @classmethod
+     @replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
+     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+         r"""
+         Instantiate one of the image processor classes of the library from a pretrained model.
+
+         The image processor class to instantiate is selected based on the `model_type` property of the config object
+         (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
+         missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
+
+         List options
+
+         Params:
+             pretrained_model_name_or_path (`str` or `os.PathLike`):
+                 This can be either:
+
+                 - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+                   huggingface.co.
+                 - a path to a *directory* containing an image processor file saved using the
+                   [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+                   `./my_model_directory/`.
+                 - a path or url to a saved image processor JSON *file*, e.g.,
+                   `./my_model_directory/preprocessor_config.json`.
+             cache_dir (`str` or `os.PathLike`, *optional*):
+                 Path to a directory in which a downloaded pretrained model image processor should be cached if the
+                 standard cache should not be used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force to (re-)download the image processor files and override the cached versions
+                 if they exist.
+             resume_download:
+                 Deprecated and ignored. All downloads are now resumed by default when possible.
+                 Will be removed in v5 of Transformers.
+             proxies (`dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                 generated when running `hf auth login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we
+                 use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can
+                 be any identifier allowed by git.
+             use_fast (`bool`, *optional*, defaults to `False`):
+                 Use a fast torchvision-based image processor if it is supported for a given model.
+                 If a fast image processor is not available for a given model, a normal numpy-based image processor
+                 is returned instead.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 If `False`, then this function returns just the final image processor object. If `True`, then this
+                 function returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+                 consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+                 `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+             trust_remote_code (`bool`, *optional*, defaults to `False`):
+                 Whether or not to allow for custom models defined on the Hub in their own modeling files. This
+                 option should only be set to `True` for repositories you trust and in which you have read the code,
+                 as it will execute code present on the Hub on your local machine.
+             image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
+                 The name of the file in the model directory to use for the image processor config.
+             kwargs (`dict[str, Any]`, *optional*):
+                 The values in kwargs of any keys which are image processor attributes will be used to override the
+                 loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes
+                 is controlled by the `return_unused_kwargs` keyword parameter.
+
+         <Tip>
+
+         Passing `token=True` is required when you want to use a private model.
+
+         </Tip>
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoImageProcessor
+
+         >>> # Download image processor from huggingface.co and cache.
+         >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+
+         >>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
+         >>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
+         ```"""
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         if use_auth_token is not None:
+             warnings.warn(
+                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                 FutureWarning,
+             )
+             if kwargs.get("token") is not None:
+                 raise ValueError(
+                     "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                 )
+             kwargs["token"] = use_auth_token
+
+         config = kwargs.pop("config", None)
+         # TODO: @yoni, change in v4.48 (use_fast set to True by default)
+         use_fast = kwargs.pop("use_fast", None)
+         trust_remote_code = kwargs.pop("trust_remote_code", None)
+         kwargs["_from_auto"] = True
+
+         # Resolve the image processor config filename
+         if "image_processor_filename" in kwargs:
+             image_processor_filename = kwargs.pop("image_processor_filename")
+         elif is_timm_local_checkpoint(pretrained_model_name_or_path):
+             image_processor_filename = CONFIG_NAME
+         else:
+             image_processor_filename = IMAGE_PROCESSOR_NAME
+
+         # Load the image processor config
+         try:
+             # Main path for all transformers models and local TimmWrapper checkpoints
+             config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
+                 pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
+             )
+         except Exception as initial_exception:
+             # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
+             # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
+             # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
+             # load `config.json` and if it fails with some error, we raise the initial exception.
+             try:
+                 config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
+                     pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
+                 )
+             except Exception:
+                 raise initial_exception
+
+             # In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
+             # because only timm models have image processing in `config.json`.
+             if not is_timm_config_dict(config_dict):
+                 raise initial_exception
+
+         image_processor_type = config_dict.get("image_processor_type", None)
+         image_processor_auto_map = None
+         if "AutoImageProcessor" in config_dict.get("auto_map", {}):
+             image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
+
+         # If we still don't have the image processor class, check if we're loading from a previous feature extractor config
+         # and if so, infer the image processor class from there.
+         if image_processor_type is None and image_processor_auto_map is None:
+             feature_extractor_class = config_dict.pop("feature_extractor_type", None)
+             if feature_extractor_class is not None:
+                 image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
+             if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
+                 feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
+                 image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
+
+         # If we don't find the image processor class in the image processor config, let's try the model config.
+         if image_processor_type is None and image_processor_auto_map is None:
+             if not isinstance(config, PretrainedConfig):
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path,
+                     trust_remote_code=trust_remote_code,
+                     **kwargs,
+                 )
+             # It could be in `config.image_processor_type`
+             image_processor_type = getattr(config, "image_processor_type", None)
+             if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
+                 image_processor_auto_map = config.auto_map["AutoImageProcessor"]
+
+         image_processor_class = None
+         # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
+         if image_processor_type is not None:
+             # If use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
+             if use_fast is None:
+                 use_fast = image_processor_type.endswith("Fast")
+                 if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
+                     use_fast = True
+                     logger.warning_once(
+                         f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
+                         "This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
+                         "Note that this behavior will be extended to all models in a future release."
+                     )
+                 if not use_fast:
+                     logger.warning_once(
+                         "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
+                         "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
+                         "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
+                     )
+             if use_fast and not image_processor_type.endswith("Fast"):
+                 image_processor_type += "Fast"
+             if use_fast and not is_torchvision_available():
+                 # check if there is a slow image processor class to fall back to
+                 image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
+                 if image_processor_class is None:
+                     raise ValueError(
+                         f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
+                     )
+                 logger.warning_once(
+                     "Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
+                 )
+                 use_fast = False
+             if use_fast:
+                 for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
+                     if image_processor_type in image_processors:
+                         break
+                 else:
+                     image_processor_type = image_processor_type[:-4]
+                     use_fast = False
+                     logger.warning_once(
+                         "`use_fast` is set to `True` but the image processor class does not have a fast version. "
+                         "Falling back to the slow version."
+                     )
+                 image_processor_class = get_image_processor_class_from_name(image_processor_type)
+             else:
+                 image_processor_type_slow = image_processor_type.removesuffix("Fast")
+                 image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
+                 if image_processor_class is None and image_processor_type.endswith("Fast"):
+                     raise ValueError(
+                         f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
+                     )
+
+         has_remote_code = image_processor_auto_map is not None
+         has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
+         if has_remote_code:
+             if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
+                 # In some configs, only the slow image processor class is stored
+                 image_processor_auto_map = (image_processor_auto_map, None)
+             if use_fast and image_processor_auto_map[1] is not None:
+                 class_ref = image_processor_auto_map[1]
+             else:
+                 class_ref = image_processor_auto_map[0]
+             if "--" in class_ref:
+                 upstream_repo = class_ref.split("--")[0]
+             else:
+                 upstream_repo = None
+             trust_remote_code = resolve_trust_remote_code(
+                 trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+             )
+
+         if has_remote_code and trust_remote_code:
+             if not use_fast and image_processor_auto_map[1] is not None:
+                 _warning_fast_image_processor_available(image_processor_auto_map[1])
+
+             image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+             _ = kwargs.pop("code_revision", None)
+             image_processor_class.register_for_auto_class()
+             return image_processor_class.from_dict(config_dict, **kwargs)
+         elif image_processor_class is not None:
+             return image_processor_class.from_dict(config_dict, **kwargs)
+         # Last try: we use the IMAGE_PROCESSOR_MAPPING.
+         elif type(config) in IMAGE_PROCESSOR_MAPPING:
+             image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
+
+             image_processor_class_py, image_processor_class_fast = image_processor_tuple
+
+             if not use_fast and image_processor_class_fast is not None:
+                 _warning_fast_image_processor_available(image_processor_class_fast)
+
+             if image_processor_class_fast and (use_fast or image_processor_class_py is None):
+                 return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+             else:
+                 if image_processor_class_py is not None:
+                     return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+                 else:
+                     raise ValueError(
+                         "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
+                     )
+         raise ValueError(
+             f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
+             f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} or {CONFIG_NAME}, or one of the following "
+             f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
+         )
+
+     @staticmethod
+     def register(
+         config_class,
+         image_processor_class=None,
+         slow_image_processor_class=None,
+         fast_image_processor_class=None,
+         exist_ok=False,
+     ):
+         """
+         Register a new image processor for this class.
+
+         Args:
+             config_class ([`PretrainedConfig`]):
+                 The configuration corresponding to the model to register.
+             image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
+         """
+         if image_processor_class is not None:
+             if slow_image_processor_class is not None:
+                 raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
+             warnings.warn(
+                 "The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
+                 FutureWarning,
+             )
+             slow_image_processor_class = image_processor_class
+
+         if slow_image_processor_class is None and fast_image_processor_class is None:
+             raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
+         if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
+             raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
+         if fast_image_processor_class is not None and not issubclass(
+             fast_image_processor_class, BaseImageProcessorFast
+         ):
+             raise ValueError("The `fast_image_processor_class` should inherit from `BaseImageProcessorFast`.")
+
+         if (
+             slow_image_processor_class is not None
+             and fast_image_processor_class is not None
+             and issubclass(fast_image_processor_class, BaseImageProcessorFast)
+             and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
+         ):
+             raise ValueError(
+                 "The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
+                 "consistent with the slow processor class you passed (the fast image processor has "
+                 f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}). "
+                 "Fix one of those so they match!"
+             )
+
+         # Avoid resetting a set slow/fast image processor if we are passing just the other ones.
+         if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
+             existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
+             if slow_image_processor_class is None:
+                 slow_image_processor_class = existing_slow
+             if fast_image_processor_class is None:
+                 fast_image_processor_class = existing_fast
+
+         IMAGE_PROCESSOR_MAPPING.register(
+             config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
+         )
+
+
+ __all__ = ["IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor"]
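
As a minimal usage sketch of the `AutoImageProcessor` dispatch defined above (illustrative only, not part of the upstream file; it assumes network access, Pillow, and torch/torchvision for the fast path, and the checkpoint name is just an example):

```python
# Minimal sketch: dispatching to a concrete image processor via AutoImageProcessor.
# Assumes Pillow and torch/torchvision are installed and the Hub is reachable.
from PIL import Image
from transformers import AutoImageProcessor

# The checkpoint's model_type ("vit") selects ViTImageProcessorFast through
# the IMAGE_PROCESSOR_MAPPING_NAMES table defined in the file above.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k", use_fast=True)

image = Image.new("RGB", (224, 224))  # placeholder image instead of a real photo
inputs = processor(images=image, return_tensors="pt")
print(type(processor).__name__, inputs["pixel_values"].shape)
```

If torchvision were missing, the `use_fast` fallback logic above would log a warning and return the slow numpy-based processor instead.
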
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py ADDED
The diff for this file is too large to render.
 
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py ADDED
@@ -0,0 +1,413 @@
+ # coding=utf-8
+ # Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Auto Model class."""
+
+ from collections import OrderedDict
+
+ from ...utils import logging
+ from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
+ from .configuration_auto import CONFIG_MAPPING_NAMES
+
+
+ logger = logging.get_logger(__name__)
+
+
+ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
+     [
+         # Base model mapping
+         ("albert", "FlaxAlbertModel"),
+         ("bart", "FlaxBartModel"),
+         ("beit", "FlaxBeitModel"),
+         ("bert", "FlaxBertModel"),
+         ("big_bird", "FlaxBigBirdModel"),
+         ("blenderbot", "FlaxBlenderbotModel"),
+         ("blenderbot-small", "FlaxBlenderbotSmallModel"),
+         ("bloom", "FlaxBloomModel"),
+         ("clip", "FlaxCLIPModel"),
+         ("dinov2", "FlaxDinov2Model"),
+         ("distilbert", "FlaxDistilBertModel"),
+         ("electra", "FlaxElectraModel"),
+         ("gemma", "FlaxGemmaModel"),
+         ("gpt-sw3", "FlaxGPT2Model"),
+         ("gpt2", "FlaxGPT2Model"),
+         ("gpt_neo", "FlaxGPTNeoModel"),
+         ("gptj", "FlaxGPTJModel"),
+         ("llama", "FlaxLlamaModel"),
+         ("longt5", "FlaxLongT5Model"),
+         ("marian", "FlaxMarianModel"),
+         ("mbart", "FlaxMBartModel"),
+         ("mistral", "FlaxMistralModel"),
+         ("mt5", "FlaxMT5Model"),
+         ("opt", "FlaxOPTModel"),
+         ("pegasus", "FlaxPegasusModel"),
+         ("regnet", "FlaxRegNetModel"),
+         ("resnet", "FlaxResNetModel"),
+         ("roberta", "FlaxRobertaModel"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
+         ("roformer", "FlaxRoFormerModel"),
+         ("t5", "FlaxT5Model"),
+         ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
+         ("vit", "FlaxViTModel"),
+         ("wav2vec2", "FlaxWav2Vec2Model"),
+         ("whisper", "FlaxWhisperModel"),
+         ("xglm", "FlaxXGLMModel"),
+         ("xlm-roberta", "FlaxXLMRobertaModel"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for pre-training mapping
+         ("albert", "FlaxAlbertForPreTraining"),
+         ("bart", "FlaxBartForConditionalGeneration"),
+         ("bert", "FlaxBertForPreTraining"),
+         ("big_bird", "FlaxBigBirdForPreTraining"),
+         ("electra", "FlaxElectraForPreTraining"),
+         ("longt5", "FlaxLongT5ForConditionalGeneration"),
+         ("mbart", "FlaxMBartForConditionalGeneration"),
+         ("mt5", "FlaxMT5ForConditionalGeneration"),
+         ("roberta", "FlaxRobertaForMaskedLM"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
+         ("roformer", "FlaxRoFormerForMaskedLM"),
+         ("t5", "FlaxT5ForConditionalGeneration"),
+         ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
+         ("whisper", "FlaxWhisperForConditionalGeneration"),
+         ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Masked LM mapping
+         ("albert", "FlaxAlbertForMaskedLM"),
+         ("bart", "FlaxBartForConditionalGeneration"),
+         ("bert", "FlaxBertForMaskedLM"),
+         ("big_bird", "FlaxBigBirdForMaskedLM"),
+         ("distilbert", "FlaxDistilBertForMaskedLM"),
+         ("electra", "FlaxElectraForMaskedLM"),
+         ("mbart", "FlaxMBartForConditionalGeneration"),
+         ("roberta", "FlaxRobertaForMaskedLM"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
+         ("roformer", "FlaxRoFormerForMaskedLM"),
+         ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Seq2Seq Causal LM mapping
+         ("bart", "FlaxBartForConditionalGeneration"),
+         ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
+         ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
+         ("encoder-decoder", "FlaxEncoderDecoderModel"),
+         ("longt5", "FlaxLongT5ForConditionalGeneration"),
+         ("marian", "FlaxMarianMTModel"),
+         ("mbart", "FlaxMBartForConditionalGeneration"),
+         ("mt5", "FlaxMT5ForConditionalGeneration"),
+         ("pegasus", "FlaxPegasusForConditionalGeneration"),
+         ("t5", "FlaxT5ForConditionalGeneration"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Image-classification
+         ("beit", "FlaxBeitForImageClassification"),
+         ("dinov2", "FlaxDinov2ForImageClassification"),
+         ("regnet", "FlaxRegNetForImageClassification"),
+         ("resnet", "FlaxResNetForImageClassification"),
+         ("vit", "FlaxViTForImageClassification"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+     [
+         ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Causal LM mapping
+         ("bart", "FlaxBartForCausalLM"),
+         ("bert", "FlaxBertForCausalLM"),
+         ("big_bird", "FlaxBigBirdForCausalLM"),
+         ("bloom", "FlaxBloomForCausalLM"),
+         ("electra", "FlaxElectraForCausalLM"),
+         ("gemma", "FlaxGemmaForCausalLM"),
+         ("gpt-sw3", "FlaxGPT2LMHeadModel"),
+         ("gpt2", "FlaxGPT2LMHeadModel"),
+         ("gpt_neo", "FlaxGPTNeoForCausalLM"),
+         ("gptj", "FlaxGPTJForCausalLM"),
+         ("llama", "FlaxLlamaForCausalLM"),
+         ("mistral", "FlaxMistralForCausalLM"),
+         ("opt", "FlaxOPTForCausalLM"),
+         ("roberta", "FlaxRobertaForCausalLM"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
+         ("xglm", "FlaxXGLMForCausalLM"),
+         ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Sequence Classification mapping
+         ("albert", "FlaxAlbertForSequenceClassification"),
+         ("bart", "FlaxBartForSequenceClassification"),
+         ("bert", "FlaxBertForSequenceClassification"),
+         ("big_bird", "FlaxBigBirdForSequenceClassification"),
+         ("distilbert", "FlaxDistilBertForSequenceClassification"),
+         ("electra", "FlaxElectraForSequenceClassification"),
+         ("mbart", "FlaxMBartForSequenceClassification"),
+         ("roberta", "FlaxRobertaForSequenceClassification"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
+         ("roformer", "FlaxRoFormerForSequenceClassification"),
+         ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Question Answering mapping
+         ("albert", "FlaxAlbertForQuestionAnswering"),
+         ("bart", "FlaxBartForQuestionAnswering"),
+         ("bert", "FlaxBertForQuestionAnswering"),
+         ("big_bird", "FlaxBigBirdForQuestionAnswering"),
+         ("distilbert", "FlaxDistilBertForQuestionAnswering"),
+         ("electra", "FlaxElectraForQuestionAnswering"),
+         ("mbart", "FlaxMBartForQuestionAnswering"),
+         ("roberta", "FlaxRobertaForQuestionAnswering"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
+         ("roformer", "FlaxRoFormerForQuestionAnswering"),
+         ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Token Classification mapping
+         ("albert", "FlaxAlbertForTokenClassification"),
+         ("bert", "FlaxBertForTokenClassification"),
+         ("big_bird", "FlaxBigBirdForTokenClassification"),
+         ("distilbert", "FlaxDistilBertForTokenClassification"),
+         ("electra", "FlaxElectraForTokenClassification"),
+         ("roberta", "FlaxRobertaForTokenClassification"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
+         ("roformer", "FlaxRoFormerForTokenClassification"),
+         ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Multiple Choice mapping
+         ("albert", "FlaxAlbertForMultipleChoice"),
+         ("bert", "FlaxBertForMultipleChoice"),
+         ("big_bird", "FlaxBigBirdForMultipleChoice"),
+         ("distilbert", "FlaxDistilBertForMultipleChoice"),
+         ("electra", "FlaxElectraForMultipleChoice"),
+         ("roberta", "FlaxRobertaForMultipleChoice"),
+         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
+         ("roformer", "FlaxRoFormerForMultipleChoice"),
+         ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
+     [
+         ("bert", "FlaxBertForNextSentencePrediction"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
+     [
+         ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
+         ("whisper", "FlaxWhisperForConditionalGeneration"),
+     ]
+ )
+
+ FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         ("whisper", "FlaxWhisperForAudioClassification"),
+     ]
+ )
+
+ FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
+ FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+ )
+ FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+ )
+
+
+ class FlaxAutoModel(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_MAPPING
+
+
+ FlaxAutoModel = auto_class_update(FlaxAutoModel)
+
+
+ class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
+
+
+ FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
+
+
+ class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
+
+
+ FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
+
+
+ class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
+
+
+ FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
+
+
+ class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+ FlaxAutoModelForSeq2SeqLM = auto_class_update(
+     FlaxAutoModelForSeq2SeqLM,
+     head_doc="sequence-to-sequence language modeling",
+     checkpoint_for_example="google-t5/t5-base",
+ )
+
+
+ class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+
+
+ FlaxAutoModelForSequenceClassification = auto_class_update(
+     FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
+ )
+
+
+ class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
+
+
+ FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
+
+
+ class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+
+
+ FlaxAutoModelForTokenClassification = auto_class_update(
+     FlaxAutoModelForTokenClassification, head_doc="token classification"
+ )
+
+
+ class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
+
+
+ FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
+
+
+ class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
+
+
+ FlaxAutoModelForNextSentencePrediction = auto_class_update(
+     FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
+ )
+
+
+ class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+
+
+ FlaxAutoModelForImageClassification = auto_class_update(
+     FlaxAutoModelForImageClassification, head_doc="image classification"
+ )
+
+
+ class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+ FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
+
+ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+     _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+
+
+ FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
+     FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
+ )
+
+ __all__ = [
+     "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+     "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
+     "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+     "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
+     "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+     "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+     "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
+     "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+     "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+     "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+     "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+     "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+     "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
+     "FLAX_MODEL_MAPPING",
+     "FlaxAutoModel",
+     "FlaxAutoModelForCausalLM",
+     "FlaxAutoModelForImageClassification",
+     "FlaxAutoModelForMaskedLM",
+     "FlaxAutoModelForMultipleChoice",
+     "FlaxAutoModelForNextSentencePrediction",
+     "FlaxAutoModelForPreTraining",
+     "FlaxAutoModelForQuestionAnswering",
+     "FlaxAutoModelForSeq2SeqLM",
+     "FlaxAutoModelForSequenceClassification",
+     "FlaxAutoModelForSpeechSeq2Seq",
+     "FlaxAutoModelForTokenClassification",
+     "FlaxAutoModelForVision2Seq",
+ ]
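
As a minimal sketch of how the Flax auto classes above dispatch (illustrative only, not part of the upstream file; it assumes `jax` and `flax` are installed and the BERT checkpoint is just an example):

```python
# Minimal sketch: FlaxAutoModel resolving a concrete Flax model class.
# Assumes jax/flax are installed and the Hub is reachable.
from transformers import AutoTokenizer, FlaxAutoModel

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# The checkpoint's model_type ("bert") selects FlaxBertModel via FLAX_MODEL_MAPPING.
model = FlaxAutoModel.from_pretrained("google-bert/bert-base-uncased")

inputs = tokenizer("Hello, world!", return_tensors="np")  # Flax models accept numpy arrays
outputs = model(**inputs)
print(type(model).__name__, outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)
```

The task-specific classes (`FlaxAutoModelForCausalLM`, `FlaxAutoModelForSequenceClassification`, etc.) work the same way, each consulting its own mapping table defined above.
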
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py ADDED
@@ -0,0 +1,776 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Auto Model class."""
16
+
17
+ import warnings
18
+ from collections import OrderedDict
19
+
20
+ from ...utils import logging
21
+ from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
22
+ from .configuration_auto import CONFIG_MAPPING_NAMES
23
+
24
+
25
+ logger = logging.get_logger(__name__)
+
+
+ TF_MODEL_MAPPING_NAMES = OrderedDict(
+     [
+         # Base model mapping
+         ("albert", "TFAlbertModel"),
+         ("bart", "TFBartModel"),
+         ("bert", "TFBertModel"),
+         ("blenderbot", "TFBlenderbotModel"),
+         ("blenderbot-small", "TFBlenderbotSmallModel"),
+         ("blip", "TFBlipModel"),
+         ("camembert", "TFCamembertModel"),
+         ("clip", "TFCLIPModel"),
+         ("convbert", "TFConvBertModel"),
+         ("convnext", "TFConvNextModel"),
+         ("convnextv2", "TFConvNextV2Model"),
+         ("ctrl", "TFCTRLModel"),
+         ("cvt", "TFCvtModel"),
+         ("data2vec-vision", "TFData2VecVisionModel"),
+         ("deberta", "TFDebertaModel"),
+         ("deberta-v2", "TFDebertaV2Model"),
+         ("deit", "TFDeiTModel"),
+         ("distilbert", "TFDistilBertModel"),
+         ("dpr", "TFDPRQuestionEncoder"),
+         ("efficientformer", "TFEfficientFormerModel"),
+         ("electra", "TFElectraModel"),
+         ("esm", "TFEsmModel"),
+         ("flaubert", "TFFlaubertModel"),
+         ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
+         ("gpt-sw3", "TFGPT2Model"),
+         ("gpt2", "TFGPT2Model"),
+         ("gptj", "TFGPTJModel"),
+         ("groupvit", "TFGroupViTModel"),
+         ("hubert", "TFHubertModel"),
+         ("idefics", "TFIdeficsModel"),
+         ("layoutlm", "TFLayoutLMModel"),
+         ("layoutlmv3", "TFLayoutLMv3Model"),
+         ("led", "TFLEDModel"),
+         ("longformer", "TFLongformerModel"),
+         ("lxmert", "TFLxmertModel"),
+         ("marian", "TFMarianModel"),
+         ("mbart", "TFMBartModel"),
+         ("mistral", "TFMistralModel"),
+         ("mobilebert", "TFMobileBertModel"),
+         ("mobilevit", "TFMobileViTModel"),
+         ("mpnet", "TFMPNetModel"),
+         ("mt5", "TFMT5Model"),
+         ("openai-gpt", "TFOpenAIGPTModel"),
+         ("opt", "TFOPTModel"),
+         ("pegasus", "TFPegasusModel"),
+         ("regnet", "TFRegNetModel"),
+         ("rembert", "TFRemBertModel"),
+         ("resnet", "TFResNetModel"),
+         ("roberta", "TFRobertaModel"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
+         ("roformer", "TFRoFormerModel"),
+         ("sam", "TFSamModel"),
+         ("sam_vision_model", "TFSamVisionModel"),
+         ("segformer", "TFSegformerModel"),
+         ("speech_to_text", "TFSpeech2TextModel"),
+         ("swiftformer", "TFSwiftFormerModel"),
+         ("swin", "TFSwinModel"),
+         ("t5", "TFT5Model"),
+         ("tapas", "TFTapasModel"),
+         ("transfo-xl", "TFTransfoXLModel"),
+         ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
+         ("vit", "TFViTModel"),
+         ("vit_mae", "TFViTMAEModel"),
+         ("wav2vec2", "TFWav2Vec2Model"),
+         ("whisper", "TFWhisperModel"),
+         ("xglm", "TFXGLMModel"),
+         ("xlm", "TFXLMModel"),
+         ("xlm-roberta", "TFXLMRobertaModel"),
+         ("xlnet", "TFXLNetModel"),
+     ]
+ )
+
+ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for pre-training mapping
+         ("albert", "TFAlbertForPreTraining"),
+         ("bart", "TFBartForConditionalGeneration"),
+         ("bert", "TFBertForPreTraining"),
+         ("camembert", "TFCamembertForMaskedLM"),
+         ("ctrl", "TFCTRLLMHeadModel"),
+         ("distilbert", "TFDistilBertForMaskedLM"),
+         ("electra", "TFElectraForPreTraining"),
+         ("flaubert", "TFFlaubertWithLMHeadModel"),
+         ("funnel", "TFFunnelForPreTraining"),
+         ("gpt-sw3", "TFGPT2LMHeadModel"),
+         ("gpt2", "TFGPT2LMHeadModel"),
+         ("idefics", "TFIdeficsForVisionText2Text"),
+         ("layoutlm", "TFLayoutLMForMaskedLM"),
+         ("lxmert", "TFLxmertForPreTraining"),
+         ("mobilebert", "TFMobileBertForPreTraining"),
+         ("mpnet", "TFMPNetForMaskedLM"),
+         ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+         ("roberta", "TFRobertaForMaskedLM"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
+         ("t5", "TFT5ForConditionalGeneration"),
+         ("tapas", "TFTapasForMaskedLM"),
+         ("transfo-xl", "TFTransfoXLLMHeadModel"),
+         ("vit_mae", "TFViTMAEForPreTraining"),
+         ("xlm", "TFXLMWithLMHeadModel"),
+         ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
+         ("xlnet", "TFXLNetLMHeadModel"),
+     ]
+ )
+
+ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
+     [
+         # Model with LM heads mapping
+         ("albert", "TFAlbertForMaskedLM"),
+         ("bart", "TFBartForConditionalGeneration"),
+         ("bert", "TFBertForMaskedLM"),
+         ("camembert", "TFCamembertForMaskedLM"),
+         ("convbert", "TFConvBertForMaskedLM"),
+         ("ctrl", "TFCTRLLMHeadModel"),
+         ("distilbert", "TFDistilBertForMaskedLM"),
+         ("electra", "TFElectraForMaskedLM"),
+         ("esm", "TFEsmForMaskedLM"),
+         ("flaubert", "TFFlaubertWithLMHeadModel"),
+         ("funnel", "TFFunnelForMaskedLM"),
+         ("gpt-sw3", "TFGPT2LMHeadModel"),
+         ("gpt2", "TFGPT2LMHeadModel"),
+         ("gptj", "TFGPTJForCausalLM"),
+         ("layoutlm", "TFLayoutLMForMaskedLM"),
+         ("led", "TFLEDForConditionalGeneration"),
+         ("longformer", "TFLongformerForMaskedLM"),
+         ("marian", "TFMarianMTModel"),
+         ("mobilebert", "TFMobileBertForMaskedLM"),
+         ("mpnet", "TFMPNetForMaskedLM"),
+         ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+         ("rembert", "TFRemBertForMaskedLM"),
+         ("roberta", "TFRobertaForMaskedLM"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
+         ("roformer", "TFRoFormerForMaskedLM"),
+         ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
+         ("t5", "TFT5ForConditionalGeneration"),
+         ("tapas", "TFTapasForMaskedLM"),
+         ("transfo-xl", "TFTransfoXLLMHeadModel"),
+         ("whisper", "TFWhisperForConditionalGeneration"),
+         ("xlm", "TFXLMWithLMHeadModel"),
+         ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
+         ("xlnet", "TFXLNetLMHeadModel"),
+     ]
+ )
+
+ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Causal LM mapping
+         ("bert", "TFBertLMHeadModel"),
+         ("camembert", "TFCamembertForCausalLM"),
+         ("ctrl", "TFCTRLLMHeadModel"),
+         ("gpt-sw3", "TFGPT2LMHeadModel"),
+         ("gpt2", "TFGPT2LMHeadModel"),
+         ("gptj", "TFGPTJForCausalLM"),
+         ("mistral", "TFMistralForCausalLM"),
+         ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+         ("opt", "TFOPTForCausalLM"),
+         ("rembert", "TFRemBertForCausalLM"),
+         ("roberta", "TFRobertaForCausalLM"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
+         ("roformer", "TFRoFormerForCausalLM"),
+         ("transfo-xl", "TFTransfoXLLMHeadModel"),
+         ("xglm", "TFXGLMForCausalLM"),
+         ("xlm", "TFXLMWithLMHeadModel"),
+         ("xlm-roberta", "TFXLMRobertaForCausalLM"),
+         ("xlnet", "TFXLNetLMHeadModel"),
+     ]
+ )
+
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
+     [
+         ("deit", "TFDeiTForMaskedImageModeling"),
+         ("swin", "TFSwinForMaskedImageModeling"),
+     ]
+ )
+
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Image classification mapping
+         ("convnext", "TFConvNextForImageClassification"),
+         ("convnextv2", "TFConvNextV2ForImageClassification"),
+         ("cvt", "TFCvtForImageClassification"),
+         ("data2vec-vision", "TFData2VecVisionForImageClassification"),
+         ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
+         (
+             "efficientformer",
+             ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
+         ),
+         ("mobilevit", "TFMobileViTForImageClassification"),
+         ("regnet", "TFRegNetForImageClassification"),
+         ("resnet", "TFResNetForImageClassification"),
+         ("segformer", "TFSegformerForImageClassification"),
+         ("swiftformer", "TFSwiftFormerForImageClassification"),
+         ("swin", "TFSwinForImageClassification"),
+         ("vit", "TFViTForImageClassification"),
+     ]
+ )
+
+
+ TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Zero Shot Image Classification mapping
+         ("blip", "TFBlipModel"),
+         ("clip", "TFCLIPModel"),
+     ]
+ )
+
+
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Semantic Segmentation mapping
+         ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
+         ("mobilevit", "TFMobileViTForSemanticSegmentation"),
+         ("segformer", "TFSegformerForSemanticSegmentation"),
+     ]
+ )
+
+ TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
+     [
+         ("blip", "TFBlipForConditionalGeneration"),
+         ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
+     ]
+ )
+
+ TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Masked LM mapping
+         ("albert", "TFAlbertForMaskedLM"),
+         ("bert", "TFBertForMaskedLM"),
+         ("camembert", "TFCamembertForMaskedLM"),
+         ("convbert", "TFConvBertForMaskedLM"),
+         ("deberta", "TFDebertaForMaskedLM"),
+         ("deberta-v2", "TFDebertaV2ForMaskedLM"),
+         ("distilbert", "TFDistilBertForMaskedLM"),
+         ("electra", "TFElectraForMaskedLM"),
+         ("esm", "TFEsmForMaskedLM"),
+         ("flaubert", "TFFlaubertWithLMHeadModel"),
+         ("funnel", "TFFunnelForMaskedLM"),
+         ("layoutlm", "TFLayoutLMForMaskedLM"),
+         ("longformer", "TFLongformerForMaskedLM"),
+         ("mobilebert", "TFMobileBertForMaskedLM"),
+         ("mpnet", "TFMPNetForMaskedLM"),
+         ("rembert", "TFRemBertForMaskedLM"),
+         ("roberta", "TFRobertaForMaskedLM"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
+         ("roformer", "TFRoFormerForMaskedLM"),
+         ("tapas", "TFTapasForMaskedLM"),
+         ("xlm", "TFXLMWithLMHeadModel"),
+         ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
+     ]
+ )
+
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Seq2Seq Causal LM mapping
+         ("bart", "TFBartForConditionalGeneration"),
+         ("blenderbot", "TFBlenderbotForConditionalGeneration"),
+         ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
+         ("encoder-decoder", "TFEncoderDecoderModel"),
+         ("led", "TFLEDForConditionalGeneration"),
+         ("marian", "TFMarianMTModel"),
+         ("mbart", "TFMBartForConditionalGeneration"),
+         ("mt5", "TFMT5ForConditionalGeneration"),
+         ("pegasus", "TFPegasusForConditionalGeneration"),
+         ("t5", "TFT5ForConditionalGeneration"),
+     ]
+ )
+
+ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
+     [
+         ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
+         ("whisper", "TFWhisperForConditionalGeneration"),
+     ]
+ )
+
+ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Sequence Classification mapping
+         ("albert", "TFAlbertForSequenceClassification"),
+         ("bart", "TFBartForSequenceClassification"),
+         ("bert", "TFBertForSequenceClassification"),
+         ("camembert", "TFCamembertForSequenceClassification"),
+         ("convbert", "TFConvBertForSequenceClassification"),
+         ("ctrl", "TFCTRLForSequenceClassification"),
+         ("deberta", "TFDebertaForSequenceClassification"),
+         ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
+         ("distilbert", "TFDistilBertForSequenceClassification"),
+         ("electra", "TFElectraForSequenceClassification"),
+         ("esm", "TFEsmForSequenceClassification"),
+         ("flaubert", "TFFlaubertForSequenceClassification"),
+         ("funnel", "TFFunnelForSequenceClassification"),
+         ("gpt-sw3", "TFGPT2ForSequenceClassification"),
+         ("gpt2", "TFGPT2ForSequenceClassification"),
+         ("gptj", "TFGPTJForSequenceClassification"),
+         ("layoutlm", "TFLayoutLMForSequenceClassification"),
+         ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
+         ("longformer", "TFLongformerForSequenceClassification"),
+         ("mistral", "TFMistralForSequenceClassification"),
+         ("mobilebert", "TFMobileBertForSequenceClassification"),
+         ("mpnet", "TFMPNetForSequenceClassification"),
+         ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
+         ("rembert", "TFRemBertForSequenceClassification"),
+         ("roberta", "TFRobertaForSequenceClassification"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
+         ("roformer", "TFRoFormerForSequenceClassification"),
+         ("tapas", "TFTapasForSequenceClassification"),
+         ("transfo-xl", "TFTransfoXLForSequenceClassification"),
+         ("xlm", "TFXLMForSequenceClassification"),
+         ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
+         ("xlnet", "TFXLNetForSequenceClassification"),
+     ]
+ )
+
+ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Question Answering mapping
+         ("albert", "TFAlbertForQuestionAnswering"),
+         ("bert", "TFBertForQuestionAnswering"),
+         ("camembert", "TFCamembertForQuestionAnswering"),
+         ("convbert", "TFConvBertForQuestionAnswering"),
+         ("deberta", "TFDebertaForQuestionAnswering"),
+         ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
+         ("distilbert", "TFDistilBertForQuestionAnswering"),
+         ("electra", "TFElectraForQuestionAnswering"),
+         ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
+         ("funnel", "TFFunnelForQuestionAnswering"),
+         ("gptj", "TFGPTJForQuestionAnswering"),
+         ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
+         ("longformer", "TFLongformerForQuestionAnswering"),
+         ("mobilebert", "TFMobileBertForQuestionAnswering"),
+         ("mpnet", "TFMPNetForQuestionAnswering"),
+         ("rembert", "TFRemBertForQuestionAnswering"),
+         ("roberta", "TFRobertaForQuestionAnswering"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
+         ("roformer", "TFRoFormerForQuestionAnswering"),
+         ("xlm", "TFXLMForQuestionAnsweringSimple"),
+         ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
+         ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
+     ]
+ )
+ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
+
+ TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+     [
+         ("layoutlm", "TFLayoutLMForQuestionAnswering"),
+         ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
+     ]
+ )
+
+
+ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Table Question Answering mapping
+         ("tapas", "TFTapasForQuestionAnswering"),
+     ]
+ )
+
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Token Classification mapping
+         ("albert", "TFAlbertForTokenClassification"),
+         ("bert", "TFBertForTokenClassification"),
+         ("camembert", "TFCamembertForTokenClassification"),
+         ("convbert", "TFConvBertForTokenClassification"),
+         ("deberta", "TFDebertaForTokenClassification"),
+         ("deberta-v2", "TFDebertaV2ForTokenClassification"),
+         ("distilbert", "TFDistilBertForTokenClassification"),
+         ("electra", "TFElectraForTokenClassification"),
+         ("esm", "TFEsmForTokenClassification"),
+         ("flaubert", "TFFlaubertForTokenClassification"),
+         ("funnel", "TFFunnelForTokenClassification"),
+         ("layoutlm", "TFLayoutLMForTokenClassification"),
+         ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
+         ("longformer", "TFLongformerForTokenClassification"),
+         ("mobilebert", "TFMobileBertForTokenClassification"),
+         ("mpnet", "TFMPNetForTokenClassification"),
+         ("rembert", "TFRemBertForTokenClassification"),
+         ("roberta", "TFRobertaForTokenClassification"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
+         ("roformer", "TFRoFormerForTokenClassification"),
+         ("xlm", "TFXLMForTokenClassification"),
+         ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
+         ("xlnet", "TFXLNetForTokenClassification"),
+     ]
+ )
+
+ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
+     [
+         # Model for Multiple Choice mapping
+         ("albert", "TFAlbertForMultipleChoice"),
+         ("bert", "TFBertForMultipleChoice"),
+         ("camembert", "TFCamembertForMultipleChoice"),
+         ("convbert", "TFConvBertForMultipleChoice"),
+         ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
+         ("distilbert", "TFDistilBertForMultipleChoice"),
+         ("electra", "TFElectraForMultipleChoice"),
+         ("flaubert", "TFFlaubertForMultipleChoice"),
+         ("funnel", "TFFunnelForMultipleChoice"),
+         ("longformer", "TFLongformerForMultipleChoice"),
+         ("mobilebert", "TFMobileBertForMultipleChoice"),
+         ("mpnet", "TFMPNetForMultipleChoice"),
+         ("rembert", "TFRemBertForMultipleChoice"),
+         ("roberta", "TFRobertaForMultipleChoice"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
+         ("roformer", "TFRoFormerForMultipleChoice"),
+         ("xlm", "TFXLMForMultipleChoice"),
+         ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
+         ("xlnet", "TFXLNetForMultipleChoice"),
+     ]
+ )
+
+ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
+     [
+         ("bert", "TFBertForNextSentencePrediction"),
+         ("mobilebert", "TFMobileBertForNextSentencePrediction"),
+     ]
+ )
+ TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
+     [
+         ("sam", "TFSamModel"),
+     ]
+ )
+ TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
+     [
+         ("albert", "TFAlbertModel"),
+         ("bert", "TFBertModel"),
+         ("convbert", "TFConvBertModel"),
+         ("deberta", "TFDebertaModel"),
+         ("deberta-v2", "TFDebertaV2Model"),
+         ("distilbert", "TFDistilBertModel"),
+         ("electra", "TFElectraModel"),
+         ("flaubert", "TFFlaubertModel"),
+         ("longformer", "TFLongformerModel"),
+         ("mobilebert", "TFMobileBertModel"),
+         ("mt5", "TFMT5EncoderModel"),
+         ("rembert", "TFRemBertModel"),
+         ("roberta", "TFRobertaModel"),
+         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
+         ("roformer", "TFRoFormerModel"),
+         ("t5", "TFT5EncoderModel"),
+         ("xlm", "TFXLMModel"),
+         ("xlm-roberta", "TFXLMRobertaModel"),
+     ]
+ )
+
+ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
+ TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+ TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
+ TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
+ TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+ )
+ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+ )
+
+ TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
+     CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
+ )
+
+ TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
+
+
+ class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
+
+
+ class TFAutoModelForTextEncoding(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
+
+
+ class TFAutoModel(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_MAPPING
+
+
+ TFAutoModel = auto_class_update(TFAutoModel)
+
+
+ class TFAutoModelForAudioClassification(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+
+
+ TFAutoModelForAudioClassification = auto_class_update(
+     TFAutoModelForAudioClassification, head_doc="audio classification"
+ )
+
+
+ class TFAutoModelForPreTraining(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
+
+
+ TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
+
+
+ # Private on purpose, the public class will add the deprecation warnings.
+ class _TFAutoModelWithLMHead(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
+
+
+ _TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
+
+
+ class TFAutoModelForCausalLM(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
+
+
+ TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
+
+
+ class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
+
+
+ TFAutoModelForMaskedImageModeling = auto_class_update(
+     TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
+ )
+
+
+ class TFAutoModelForImageClassification(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+
+
+ TFAutoModelForImageClassification = auto_class_update(
+     TFAutoModelForImageClassification, head_doc="image classification"
+ )
+
+
+ class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
+
+
+ TFAutoModelForZeroShotImageClassification = auto_class_update(
+     TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
+ )
+
+
+ class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
+
+
+ TFAutoModelForSemanticSegmentation = auto_class_update(
+     TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
+ )
+
+
+ class TFAutoModelForVision2Seq(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
+
+
+ TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
+
+
+ class TFAutoModelForMaskedLM(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
+
+
+ TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
+
+
+ class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+ TFAutoModelForSeq2SeqLM = auto_class_update(
+     TFAutoModelForSeq2SeqLM,
+     head_doc="sequence-to-sequence language modeling",
+     checkpoint_for_example="google-t5/t5-base",
+ )
+
+
+ class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+
+
+ TFAutoModelForSequenceClassification = auto_class_update(
+     TFAutoModelForSequenceClassification, head_doc="sequence classification"
+ )
+
+
+ class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
+
+
+ TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
+
+
+ class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
+
+
+ TFAutoModelForDocumentQuestionAnswering = auto_class_update(
+     TFAutoModelForDocumentQuestionAnswering,
+     head_doc="document question answering",
+     checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
+ )
+
+
+ class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
+
+
+ TFAutoModelForTableQuestionAnswering = auto_class_update(
+     TFAutoModelForTableQuestionAnswering,
+     head_doc="table question answering",
+     checkpoint_for_example="google/tapas-base-finetuned-wtq",
+ )
+
+
+ class TFAutoModelForTokenClassification(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+
+
+ TFAutoModelForTokenClassification = auto_class_update(
+     TFAutoModelForTokenClassification, head_doc="token classification"
+ )
+
+
+ class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
+
+
+ TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
+
+
+ class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
+
+
+ TFAutoModelForNextSentencePrediction = auto_class_update(
+     TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
+ )
+
+
+ class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+     _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+
+
+ TFAutoModelForSpeechSeq2Seq = auto_class_update(
+     TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
+ )
+
+
+ class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
+     @classmethod
+     def from_config(cls, config):
+         warnings.warn(
+             "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
+             " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
+             " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
+             FutureWarning,
+         )
+         return super().from_config(config)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+         warnings.warn(
+             "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
+             " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
+             " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
+             FutureWarning,
+         )
+         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
+ __all__ = [
+     "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+     "TF_MODEL_FOR_CAUSAL_LM_MAPPING",
+     "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+     "TF_MODEL_FOR_MASK_GENERATION_MAPPING",
+     "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
+     "TF_MODEL_FOR_MASKED_LM_MAPPING",
+     "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+     "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+     "TF_MODEL_FOR_PRETRAINING_MAPPING",
+     "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+     "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+     "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
+     "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+     "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+     "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+     "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
+     "TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
+     "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+     "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
+     "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
+     "TF_MODEL_MAPPING",
+     "TF_MODEL_WITH_LM_HEAD_MAPPING",
+     "TFAutoModel",
+     "TFAutoModelForAudioClassification",
+     "TFAutoModelForCausalLM",
+     "TFAutoModelForImageClassification",
+     "TFAutoModelForMaskedImageModeling",
+     "TFAutoModelForMaskedLM",
+     "TFAutoModelForMaskGeneration",
+     "TFAutoModelForMultipleChoice",
+     "TFAutoModelForNextSentencePrediction",
+     "TFAutoModelForPreTraining",
+     "TFAutoModelForDocumentQuestionAnswering",
+     "TFAutoModelForQuestionAnswering",
+     "TFAutoModelForSemanticSegmentation",
+     "TFAutoModelForSeq2SeqLM",
+     "TFAutoModelForSequenceClassification",
+     "TFAutoModelForSpeechSeq2Seq",
+     "TFAutoModelForTableQuestionAnswering",
+     "TFAutoModelForTextEncoding",
+     "TFAutoModelForTokenClassification",
+     "TFAutoModelForVision2Seq",
+     "TFAutoModelForZeroShotImageClassification",
+     "TFAutoModelWithLMHead",
+ ]
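[Editor's note] The file above only builds lookup tables and thin `_BaseAutoModelClass` subclasses; the actual dispatch happens in `from_pretrained`, which reads the checkpoint config's `model_type` and resolves it through the corresponding `_LazyAutoMapping`. A minimal usage sketch, assuming TensorFlow (and the TF extras of `transformers`) are installed; the checkpoint name "openai-community/gpt2" is illustrative and not taken from this diff:

    from transformers import TFAutoModel, TFAutoModelForCausalLM

    # model_type "gpt2" in the checkpoint config resolves via the tables above
    base = TFAutoModel.from_pretrained("openai-community/gpt2")           # -> TFGPT2Model via TF_MODEL_MAPPING
    lm = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")  # -> TFGPT2LMHeadModel via TF_MODEL_FOR_CAUSAL_LM_MAPPING

The deprecated `TFAutoModelWithLMHead` still resolves through TF_MODEL_WITH_LM_HEAD_MAPPING, but, as its overrides show, it emits a FutureWarning steering callers to the three task-specific classes.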
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py ADDED
@@ -0,0 +1,443 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """AutoProcessor class."""
+
+ import importlib
+ import inspect
+ import json
+ import warnings
+ from collections import OrderedDict
+
+ # Build the list of all feature extractors
+ from ...configuration_utils import PretrainedConfig
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+ from ...feature_extraction_utils import FeatureExtractionMixin
+ from ...image_processing_utils import ImageProcessingMixin
+ from ...processing_utils import ProcessorMixin
+ from ...tokenization_utils import TOKENIZER_CONFIG_FILE
+ from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
+ from ...video_processing_utils import BaseVideoProcessor
+ from .auto_factory import _LazyAutoMapping
+ from .configuration_auto import (
+     CONFIG_MAPPING_NAMES,
+     AutoConfig,
+     model_type_to_module_name,
+     replace_list_option_in_docstrings,
+ )
+ from .feature_extraction_auto import AutoFeatureExtractor
+ from .image_processing_auto import AutoImageProcessor
+ from .tokenization_auto import AutoTokenizer
+
+
+ logger = logging.get_logger(__name__)
+
+ PROCESSOR_MAPPING_NAMES = OrderedDict(
+     [
+         ("aimv2", "CLIPProcessor"),
+         ("align", "AlignProcessor"),
+         ("altclip", "AltCLIPProcessor"),
+         ("aria", "AriaProcessor"),
+         ("aya_vision", "AyaVisionProcessor"),
+         ("bark", "BarkProcessor"),
+         ("blip", "BlipProcessor"),
+         ("blip-2", "Blip2Processor"),
+         ("bridgetower", "BridgeTowerProcessor"),
+         ("chameleon", "ChameleonProcessor"),
+         ("chinese_clip", "ChineseCLIPProcessor"),
+         ("clap", "ClapProcessor"),
+         ("clip", "CLIPProcessor"),
+         ("clipseg", "CLIPSegProcessor"),
+         ("clvp", "ClvpProcessor"),
+         ("cohere2_vision", "Cohere2VisionProcessor"),
+         ("colpali", "ColPaliProcessor"),
+         ("colqwen2", "ColQwen2Processor"),
+         ("deepseek_vl", "DeepseekVLProcessor"),
+         ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"),
+         ("dia", "DiaProcessor"),
+         ("edgetam", "Sam2Processor"),
+         ("emu3", "Emu3Processor"),
+         ("evolla", "EvollaProcessor"),
+         ("flava", "FlavaProcessor"),
+         ("florence2", "Florence2Processor"),
+         ("fuyu", "FuyuProcessor"),
+         ("gemma3", "Gemma3Processor"),
+         ("gemma3n", "Gemma3nProcessor"),
+         ("git", "GitProcessor"),
+         ("glm4v", "Glm4vProcessor"),
+         ("glm4v_moe", "Glm4vProcessor"),
+         ("got_ocr2", "GotOcr2Processor"),
+         ("granite_speech", "GraniteSpeechProcessor"),
+         ("grounding-dino", "GroundingDinoProcessor"),
+         ("groupvit", "CLIPProcessor"),
+         ("hubert", "Wav2Vec2Processor"),
+         ("idefics", "IdeficsProcessor"),
+         ("idefics2", "Idefics2Processor"),
+         ("idefics3", "Idefics3Processor"),
+         ("instructblip", "InstructBlipProcessor"),
+         ("instructblipvideo", "InstructBlipVideoProcessor"),
+         ("internvl", "InternVLProcessor"),
+         ("janus", "JanusProcessor"),
+         ("kosmos-2", "Kosmos2Processor"),
+         ("kosmos-2.5", "Kosmos2_5Processor"),
+         ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
+         ("layoutlmv2", "LayoutLMv2Processor"),
+         ("layoutlmv3", "LayoutLMv3Processor"),
+         ("lfm2_vl", "Lfm2VlProcessor"),
+         ("llama4", "Llama4Processor"),
+         ("llava", "LlavaProcessor"),
+         ("llava_next", "LlavaNextProcessor"),
+         ("llava_next_video", "LlavaNextVideoProcessor"),
+         ("llava_onevision", "LlavaOnevisionProcessor"),
+         ("markuplm", "MarkupLMProcessor"),
+         ("mctct", "MCTCTProcessor"),
+         ("metaclip_2", "CLIPProcessor"),
+         ("mgp-str", "MgpstrProcessor"),
+         ("mistral3", "PixtralProcessor"),
+         ("mllama", "MllamaProcessor"),
+         ("mm-grounding-dino", "GroundingDinoProcessor"),
+         ("moonshine", "Wav2Vec2Processor"),
+         ("oneformer", "OneFormerProcessor"),
+         ("ovis2", "Ovis2Processor"),
+         ("owlv2", "Owlv2Processor"),
+         ("owlvit", "OwlViTProcessor"),
+         ("paligemma", "PaliGemmaProcessor"),
+         ("perception_lm", "PerceptionLMProcessor"),
+         ("phi4_multimodal", "Phi4MultimodalProcessor"),
+         ("pix2struct", "Pix2StructProcessor"),
+         ("pixtral", "PixtralProcessor"),
+         ("pop2piano", "Pop2PianoProcessor"),
+         ("qwen2_5_omni", "Qwen2_5OmniProcessor"),
+         ("qwen2_5_vl", "Qwen2_5_VLProcessor"),
+         ("qwen2_audio", "Qwen2AudioProcessor"),
+         ("qwen2_vl", "Qwen2VLProcessor"),
+         ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"),
+         ("qwen3_vl", "Qwen3VLProcessor"),
+         ("qwen3_vl_moe", "Qwen3VLProcessor"),
+         ("sam", "SamProcessor"),
+         ("sam2", "Sam2Processor"),
+         ("sam_hq", "SamHQProcessor"),
+         ("seamless_m4t", "SeamlessM4TProcessor"),
+         ("sew", "Wav2Vec2Processor"),
+         ("sew-d", "Wav2Vec2Processor"),
+         ("shieldgemma2", "ShieldGemma2Processor"),
+         ("siglip", "SiglipProcessor"),
+         ("siglip2", "Siglip2Processor"),
+         ("smolvlm", "SmolVLMProcessor"),
+         ("speech_to_text", "Speech2TextProcessor"),
+         ("speech_to_text_2", "Speech2Text2Processor"),
+         ("speecht5", "SpeechT5Processor"),
+         ("trocr", "TrOCRProcessor"),
+         ("tvlt", "TvltProcessor"),
+         ("tvp", "TvpProcessor"),
+         ("udop", "UdopProcessor"),
+         ("unispeech", "Wav2Vec2Processor"),
+         ("unispeech-sat", "Wav2Vec2Processor"),
+         ("video_llava", "VideoLlavaProcessor"),
+         ("vilt", "ViltProcessor"),
+         ("vipllava", "LlavaProcessor"),
+         ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
+         ("voxtral", "VoxtralProcessor"),
+         ("wav2vec2", "Wav2Vec2Processor"),
+         ("wav2vec2-bert", "Wav2Vec2Processor"),
+         ("wav2vec2-conformer", "Wav2Vec2Processor"),
+         ("wavlm", "Wav2Vec2Processor"),
+         ("whisper", "WhisperProcessor"),
+         ("xclip", "XCLIPProcessor"),
+     ]
+ )
+
+ PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
+
+
+ def processor_class_from_name(class_name: str):
+     for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
+         if class_name in processors:
+             module_name = model_type_to_module_name(module_name)
+
+             module = importlib.import_module(f".{module_name}", "transformers.models")
+             try:
+                 return getattr(module, class_name)
+             except AttributeError:
+                 continue
+
+     for processor in PROCESSOR_MAPPING._extra_content.values():
+         if getattr(processor, "__name__", None) == class_name:
+             return processor
+
+     # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
+     # init and we return the proper dummy to get an appropriate error message.
+     main_module = importlib.import_module("transformers")
+     if hasattr(main_module, class_name):
+         return getattr(main_module, class_name)
+
+     return None
+
+
+ class AutoProcessor:
+     r"""
+     This is a generic processor class that will be instantiated as one of the processor classes of the library when
+     created with the [`AutoProcessor.from_pretrained`] class method.
+
+     This class cannot be instantiated directly using `__init__()` (throws an error).
+     """
+
+     def __init__(self):
+         raise OSError(
+             "AutoProcessor is designed to be instantiated "
+             "using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
+         )
+
+     @classmethod
+     @replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES)
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         r"""
+         Instantiate one of the processor classes of the library from a pretrained model vocabulary.
+
+         The processor class to instantiate is selected based on the `model_type` property of the config object (either
+         passed as an argument or loaded from `pretrained_model_name_or_path` if possible):
+
+         List options
+
+         Params:
+             pretrained_model_name_or_path (`str` or `os.PathLike`):
+                 This can be either:
+
+                 - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+                   huggingface.co.
+                 - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
+                   e.g., `./my_model_directory/`.
+             cache_dir (`str` or `os.PathLike`, *optional*):
+                 Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+                 standard cache should not be used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force (re-)downloading the feature extractor files and override the cached versions
+                 if they exist.
+             resume_download:
+                 Deprecated and ignored. All downloads are now resumed by default when possible.
+                 Will be removed in v5 of Transformers.
+             proxies (`dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+             token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                 when running `hf auth login` (stored in `~/.huggingface`).
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                 identifier allowed by git.
+             return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                 If `False`, then this function returns just the final feature extractor object. If `True`, then this
+                 function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
+                 consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
+                 `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+             trust_remote_code (`bool`, *optional*, defaults to `False`):
+                 Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                 should only be set to `True` for repositories you trust and in which you have read the code, as it will
+                 execute code present on the Hub on your local machine.
+             kwargs (`dict[str, Any]`, *optional*):
+                 The values in kwargs of any keys which are feature extractor attributes will be used to override the
+                 loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
+                 controlled by the `return_unused_kwargs` keyword parameter.
+
+         <Tip>
+
+         Passing `token=True` is required when you want to use a private model.
+
+         </Tip>
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoProcessor
+
+         >>> # Download processor from huggingface.co and cache.
+         >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
+
+         >>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
+         >>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
+         ```"""
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         if use_auth_token is not None:
+             warnings.warn(
+                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                 FutureWarning,
+             )
+             if kwargs.get("token") is not None:
+                 raise ValueError(
+                     "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                 )
+             kwargs["token"] = use_auth_token
+
+         config = kwargs.pop("config", None)
+         trust_remote_code = kwargs.pop("trust_remote_code", None)
+         kwargs["_from_auto"] = True
+
+         processor_class = None
+         processor_auto_map = None
+
+         # First, let's see if we have a processor or preprocessor config.
+         # Filter the kwargs for `cached_file`.
+         cached_file_kwargs = {key: kwargs[key] for key in inspect.signature(cached_file).parameters if key in kwargs}
+         # We don't want to raise
+         cached_file_kwargs.update(
+             {
+                 "_raise_exceptions_for_gated_repo": False,
+                 "_raise_exceptions_for_missing_entries": False,
+                 "_raise_exceptions_for_connection_errors": False,
+             }
+         )
+
+         # Let's start by checking whether the processor class is saved in a processor config
+         processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
+         if processor_config_file is not None:
+             config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
+             processor_class = config_dict.get("processor_class", None)
+             if "AutoProcessor" in config_dict.get("auto_map", {}):
+                 processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+         if processor_class is None:
+             # If not found, let's check whether the processor class is saved in an image processor config
+             preprocessor_config_file = cached_file(
+                 pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
+             )
+             if preprocessor_config_file is not None:
+                 config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+                 processor_class = config_dict.get("processor_class", None)
+                 if "AutoProcessor" in config_dict.get("auto_map", {}):
+                     processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+             # Saved as video processor
+             if preprocessor_config_file is None:
+                 preprocessor_config_file = cached_file(
+                     pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
+                 )
+                 if preprocessor_config_file is not None:
+                     config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
+                         pretrained_model_name_or_path, **kwargs
+                     )
+                     processor_class = config_dict.get("processor_class", None)
+                     if "AutoProcessor" in config_dict.get("auto_map", {}):
+                         processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+             # Saved as feature extractor
+             if preprocessor_config_file is None:
+                 preprocessor_config_file = cached_file(
+                     pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
+                 )
+                 if preprocessor_config_file is not None and processor_class is None:
+                     config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
+                         pretrained_model_name_or_path, **kwargs
+                     )
+                     processor_class = config_dict.get("processor_class", None)
+                     if "AutoProcessor" in config_dict.get("auto_map", {}):
+                         processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+         if processor_class is None:
+             # Next, let's check whether the processor class is saved in a tokenizer
+             tokenizer_config_file = cached_file(
+                 pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
+             )
+             if tokenizer_config_file is not None:
+                 with open(tokenizer_config_file, encoding="utf-8") as reader:
+                     config_dict = json.load(reader)
+
+                 processor_class = config_dict.get("processor_class", None)
+                 if "AutoProcessor" in config_dict.get("auto_map", {}):
+                     processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
+
+         if processor_class is None:
+             # Otherwise, load config, if it can be loaded.
+             if not isinstance(config, PretrainedConfig):
+                 config = AutoConfig.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                 )
+
+             # And check if the config contains the processor class.
+             processor_class = getattr(config, "processor_class", None)
+             if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
+                 processor_auto_map = config.auto_map["AutoProcessor"]
+
+         if processor_class is not None:
+             processor_class = processor_class_from_name(processor_class)
+
+         has_remote_code = processor_auto_map is not None
+         has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
+         if has_remote_code:
+             if "--" in processor_auto_map:
+                 upstream_repo = processor_auto_map.split("--")[0]
+             else:
+                 upstream_repo = None
+             trust_remote_code = resolve_trust_remote_code(
+                 trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
+             )
+
+         if has_remote_code and trust_remote_code:
+             processor_class = get_class_from_dynamic_module(
+                 processor_auto_map, pretrained_model_name_or_path, **kwargs
+             )
+             _ = kwargs.pop("code_revision", None)
+             processor_class.register_for_auto_class()
+             return processor_class.from_pretrained(
+                 pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+             )
+         elif processor_class is not None:
+             return processor_class.from_pretrained(
+                 pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+             )
+         # Last try: we use the PROCESSOR_MAPPING.
+         elif type(config) in PROCESSOR_MAPPING:
+             return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         # At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
+         # tokenizer.
+         try:
+             return AutoTokenizer.from_pretrained(
+                 pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+             )
+         except Exception:
+             try:
+                 return AutoImageProcessor.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                 )
+             except Exception:
+                 pass
+
+             try:
+                 return AutoFeatureExtractor.from_pretrained(
+                     pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+                 )
+             except Exception:
+                 pass
+
+         raise ValueError(
+             f"Unrecognized processing class in {pretrained_model_name_or_path}. Can't instantiate a processor, a "
+             "tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains "
+             "the files of at least one of those processing classes."
+         )
+
+     @staticmethod
+     def register(config_class, processor_class, exist_ok=False):
+         """
+         Register a new processor for this class.
+
+         Args:
+             config_class ([`PretrainedConfig`]):
+                 The configuration corresponding to the model to register.
+             processor_class ([`ProcessorMixin`]): The processor to register.
+         """
+         PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
+
+
+ __all__ = ["PROCESSOR_MAPPING", "AutoProcessor"]
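[Editor's note] `AutoProcessor.from_pretrained` above resolves a processor by probing, in order, the processor config, the image/video/feature-extractor configs, the tokenizer config, and finally the model config, before falling back to `AutoTokenizer`, `AutoImageProcessor`, and `AutoFeatureExtractor`. A minimal sketch of the `register` hook it exposes, assuming hypothetical `MyConfig`/`MyProcessor` classes (both names are illustrative and not part of this file):

    from transformers import AutoConfig, AutoProcessor, PretrainedConfig
    from transformers.processing_utils import ProcessorMixin

    class MyConfig(PretrainedConfig):   # hypothetical config class
        model_type = "my-model"

    class MyProcessor(ProcessorMixin):  # hypothetical, schematic processor
        pass

    AutoConfig.register("my-model", MyConfig)
    AutoProcessor.register(MyConfig, MyProcessor)
    # AutoProcessor.from_pretrained can now resolve checkpoints whose config declares
    # model_type "my-model" through PROCESSOR_MAPPING (the "Last try" branch above).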
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py ADDED
@@ -0,0 +1,1235 @@
+ # coding=utf-8
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Auto Tokenizer class."""
+
+ import importlib
+ import json
+ import os
+ import warnings
+ from collections import OrderedDict
+ from typing import Any, Optional, Union
+
+ from transformers.utils.import_utils import is_mistral_common_available
+
+ from ...configuration_utils import PretrainedConfig
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+ from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
+ from ...tokenization_utils import PreTrainedTokenizer
+ from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
+ from ...utils import (
+     cached_file,
+     extract_commit_hash,
+     is_g2p_en_available,
+     is_sentencepiece_available,
+     is_tokenizers_available,
+     logging,
+ )
+ from ..encoder_decoder import EncoderDecoderConfig
+ from .auto_factory import _LazyAutoMapping
+ from .configuration_auto import (
+     CONFIG_MAPPING_NAMES,
+     AutoConfig,
+     config_class_to_model_type,
+     model_type_to_module_name,
+     replace_list_option_in_docstrings,
+ )
+
+
+ if is_tokenizers_available():
+     from ...tokenization_utils_fast import PreTrainedTokenizerFast
+ else:
+     PreTrainedTokenizerFast = None
+
+
+ logger = logging.get_logger(__name__)
+
+ # Explicit rather than inferred generics to significantly improve completion suggestion performance for language servers.
+ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
+     [
+         (
+             "aimv2",
+             (
+                 "CLIPTokenizer",
+                 "CLIPTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "albert",
+             (
+                 "AlbertTokenizer" if is_sentencepiece_available() else None,
+                 "AlbertTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+         ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+         ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
+         ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         ("bart", ("BartTokenizer", "BartTokenizerFast")),
+         (
+             "barthez",
+             (
+                 "BarthezTokenizer" if is_sentencepiece_available() else None,
+                 "BarthezTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("bartpho", ("BartphoTokenizer", None)),
+         ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
+         ("bert-japanese", ("BertJapaneseTokenizer", None)),
+         ("bertweet", ("BertweetTokenizer", None)),
+         (
+             "big_bird",
+             (
+                 "BigBirdTokenizer" if is_sentencepiece_available() else None,
+                 "BigBirdTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
+         ("biogpt", ("BioGptTokenizer", None)),
+         ("bitnet", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+         ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
+         ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
+         ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+         ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
+         ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+         ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+         ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         ("byt5", ("ByT5Tokenizer", None)),
+         (
+             "camembert",
+             (
+                 "CamembertTokenizer" if is_sentencepiece_available() else None,
+                 "CamembertTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("canine", ("CanineTokenizer", None)),
+         (
+             "chameleon",
+             (
+                 "LlamaTokenizer" if is_sentencepiece_available() else None,
+                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         (
+             "clap",
+             (
+                 "RobertaTokenizer",
+                 "RobertaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "clip",
+             (
+                 "CLIPTokenizer",
+                 "CLIPTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "clipseg",
+             (
+                 "CLIPTokenizer",
+                 "CLIPTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("clvp", ("ClvpTokenizer", None)),
+         (
+             "code_llama",
+             (
+                 "CodeLlamaTokenizer" if is_sentencepiece_available() else None,
+                 "CodeLlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
+         ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
+         ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
+         ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+         ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
+         ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
+         (
+             "cpm",
+             (
+                 "CpmTokenizer" if is_sentencepiece_available() else None,
+                 "CpmTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("cpmant", ("CpmAntTokenizer", None)),
+         ("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+         ("ctrl", ("CTRLTokenizer", None)),
+         ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
+         ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
+         ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+         ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
+         (
+             "deberta-v2",
+             (
+                 "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
+                 "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "deepseek_v2",
+             (
+                 "LlamaTokenizer" if is_sentencepiece_available() else None,
+                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "deepseek_v3",
+             (
+                 "LlamaTokenizer" if is_sentencepiece_available() else None,
+                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "deepseek_vl",
+             (
+                 "LlamaTokenizer" if is_sentencepiece_available() else None,
+                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "deepseek_vl_hybrid",
+             (
+                 "LlamaTokenizer" if is_sentencepiece_available() else None,
+                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("dia", ("DiaTokenizer", None)),
+         (
+             "diffllama",
+             (
+                 "LlamaTokenizer" if is_sentencepiece_available() else None,
+                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
+         (
+             "dpr",
+             (
+                 "DPRQuestionEncoderTokenizer",
+                 "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
+         ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+         ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+         ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+         ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+         ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
+         ("esm", ("EsmTokenizer", None)),
+         (
+             "exaone4",
+             (
+                 "GPT2Tokenizer" if is_tokenizers_available() else None,
+                 "GPT2TokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+         ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
+         (
+             "fastspeech2_conformer",
+             ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
+         ),
+         ("flaubert", ("FlaubertTokenizer", None)),
+         ("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+         ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
+         ("fsmt", ("FSMTTokenizer", None)),
+         ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
+         (
+             "gemma",
+             (
+                 "GemmaTokenizer" if is_sentencepiece_available() else None,
+                 "GemmaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "gemma2",
+             (
+                 "GemmaTokenizer" if is_sentencepiece_available() else None,
+                 "GemmaTokenizerFast" if is_tokenizers_available() else None,
+             ),
+         ),
+         (
+             "gemma3",
+             (
+                 "GemmaTokenizer" if is_sentencepiece_available() else None,
271
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
272
+ ),
273
+ ),
274
+ (
275
+ "gemma3_text",
276
+ (
277
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
278
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
279
+ ),
280
+ ),
281
+ (
282
+ "gemma3n",
283
+ (
284
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
285
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
286
+ ),
287
+ ),
288
+ (
289
+ "gemma3n_text",
290
+ (
291
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
292
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
293
+ ),
294
+ ),
295
+ ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
296
+ ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
297
+ ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
298
+ ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
299
+ ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
300
+ ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
301
+ ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
302
+ ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
303
+ ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
304
+ ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
305
+ ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
306
+ ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
307
+ ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
308
+ ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
309
+ ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
310
+ ("granite", ("GPT2Tokenizer", None)),
311
+ ("granitemoe", ("GPT2Tokenizer", None)),
312
+ ("granitemoehybrid", ("GPT2Tokenizer", None)),
313
+ ("granitemoeshared", ("GPT2Tokenizer", None)),
314
+ ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
315
+ ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
316
+ ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
317
+ ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
318
+ ("hubert", ("Wav2Vec2CTCTokenizer", None)),
319
+ ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
320
+ ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
321
+ ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
322
+ ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
323
+ ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
324
+ ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
325
+ ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
326
+ (
327
+ "jamba",
328
+ (
329
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
330
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
331
+ ),
332
+ ),
333
+ ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
334
+ (
335
+ "jetmoe",
336
+ (
337
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
338
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
339
+ ),
340
+ ),
341
+ ("jukebox", ("JukeboxTokenizer", None)),
342
+ (
343
+ "kosmos-2",
344
+ (
345
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
346
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
347
+ ),
348
+ ),
349
+ ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
350
+ ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
351
+ ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
352
+ ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
353
+ ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
354
+ ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
355
+ ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
356
+ (
357
+ "llama",
358
+ (
359
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
360
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
361
+ ),
362
+ ),
363
+ (
364
+ "llama4",
365
+ (
366
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
367
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
368
+ ),
369
+ ),
370
+ (
371
+ "llama4_text",
372
+ (
373
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
374
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
375
+ ),
376
+ ),
377
+ ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
378
+ ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
379
+ ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
380
+ ("llava_onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
381
+ ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
382
+ (
383
+ "longt5",
384
+ (
385
+ "T5Tokenizer" if is_sentencepiece_available() else None,
386
+ "T5TokenizerFast" if is_tokenizers_available() else None,
387
+ ),
388
+ ),
389
+ ("luke", ("LukeTokenizer", None)),
390
+ ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
391
+ ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
392
+ ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
393
+ ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
394
+ ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
395
+ (
396
+ "mbart",
397
+ (
398
+ "MBartTokenizer" if is_sentencepiece_available() else None,
399
+ "MBartTokenizerFast" if is_tokenizers_available() else None,
400
+ ),
401
+ ),
402
+ (
403
+ "mbart50",
404
+ (
405
+ "MBart50Tokenizer" if is_sentencepiece_available() else None,
406
+ "MBart50TokenizerFast" if is_tokenizers_available() else None,
407
+ ),
408
+ ),
409
+ ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
410
+ ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
411
+ (
412
+ "metaclip_2",
413
+ (
414
+ "XLMRobertaTokenizer",
415
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
416
+ ),
417
+ ),
418
+ ("mgp-str", ("MgpstrTokenizer", None)),
419
+ (
420
+ "minimax",
421
+ (
422
+ "GPT2Tokenizer" if is_sentencepiece_available() else None,
423
+ "GPT2TokenizerFast" if is_tokenizers_available() else None,
424
+ ),
425
+ ),
426
+ (
427
+ "ministral",
428
+ (
429
+ "MistralCommonTokenizer"
430
+ if is_mistral_common_available()
431
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
432
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
433
+ ),
434
+ ),
435
+ (
436
+ "mistral",
437
+ (
438
+ "MistralCommonTokenizer"
439
+ if is_mistral_common_available()
440
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
441
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
442
+ ),
443
+ ),
444
+ (
445
+ "mistral3",
446
+ (
447
+ "MistralCommonTokenizer"
448
+ if is_mistral_common_available()
449
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
450
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
451
+ ),
452
+ ),
453
+ (
454
+ "mixtral",
455
+ (
456
+ "MistralCommonTokenizer"
457
+ if is_mistral_common_available()
458
+ else ("LlamaTokenizer" if is_sentencepiece_available() else None),
459
+ "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
460
+ ),
461
+ ),
462
+ ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
463
+ ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
464
+ ("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
465
+ ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
466
+ ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
467
+ ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
468
+ ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
469
+ ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
470
+ ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
471
+ ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
472
+ (
473
+ "mt5",
474
+ (
475
+ "MT5Tokenizer" if is_sentencepiece_available() else None,
476
+ "MT5TokenizerFast" if is_tokenizers_available() else None,
477
+ ),
478
+ ),
479
+ ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
480
+ ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
481
+ ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
482
+ ("myt5", ("MyT5Tokenizer", None)),
483
+ ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
484
+ ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
485
+ (
486
+ "nllb",
487
+ (
488
+ "NllbTokenizer" if is_sentencepiece_available() else None,
489
+ "NllbTokenizerFast" if is_tokenizers_available() else None,
490
+ ),
491
+ ),
492
+ (
493
+ "nllb-moe",
494
+ (
495
+ "NllbTokenizer" if is_sentencepiece_available() else None,
496
+ "NllbTokenizerFast" if is_tokenizers_available() else None,
497
+ ),
498
+ ),
499
+ (
500
+ "nystromformer",
501
+ (
502
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
503
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
504
+ ),
505
+ ),
506
+ ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
507
+ ("olmo2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
508
+ ("olmo3", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
509
+ ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
510
+ (
511
+ "omdet-turbo",
512
+ ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None),
513
+ ),
514
+ ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
515
+ (
516
+ "openai-gpt",
517
+ ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None),
518
+ ),
519
+ ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
520
+ ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
521
+ ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
522
+ ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
523
+ ("parakeet", ("ParakeetCTCTokenizer", None)),
524
+ (
525
+ "pegasus",
526
+ (
527
+ "PegasusTokenizer" if is_sentencepiece_available() else None,
528
+ "PegasusTokenizerFast" if is_tokenizers_available() else None,
529
+ ),
530
+ ),
531
+ (
532
+ "pegasus_x",
533
+ (
534
+ "PegasusTokenizer" if is_sentencepiece_available() else None,
535
+ "PegasusTokenizerFast" if is_tokenizers_available() else None,
536
+ ),
537
+ ),
538
+ (
539
+ "perceiver",
540
+ (
541
+ "PerceiverTokenizer",
542
+ None,
543
+ ),
544
+ ),
545
+ (
546
+ "persimmon",
547
+ (
548
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
549
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
550
+ ),
551
+ ),
552
+ ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
553
+ ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
554
+ ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
555
+ ("phobert", ("PhobertTokenizer", None)),
556
+ ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
557
+ (
558
+ "pixtral",
559
+ (
560
+ None,
561
+ "MistralCommonTokenizer"
562
+ if is_mistral_common_available()
563
+ else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
564
+ ),
565
+ ),
566
+ ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
567
+ ("prophetnet", ("ProphetNetTokenizer", None)),
568
+ ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
569
+ (
570
+ "qwen2",
571
+ (
572
+ "Qwen2Tokenizer",
573
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
574
+ ),
575
+ ),
576
+ ("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
577
+ ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
578
+ ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
579
+ (
580
+ "qwen2_moe",
581
+ (
582
+ "Qwen2Tokenizer",
583
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
584
+ ),
585
+ ),
586
+ ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
587
+ (
588
+ "qwen3",
589
+ (
590
+ "Qwen2Tokenizer",
591
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
592
+ ),
593
+ ),
594
+ (
595
+ "qwen3_moe",
596
+ (
597
+ "Qwen2Tokenizer",
598
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
599
+ ),
600
+ ),
601
+ (
602
+ "qwen3_next",
603
+ (
604
+ "Qwen2Tokenizer",
605
+ "Qwen2TokenizerFast" if is_tokenizers_available() else None,
606
+ ),
607
+ ),
608
+ ("qwen3_omni_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
609
+ ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
610
+ ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
611
+ ("rag", ("RagTokenizer", None)),
612
+ ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
613
+ (
614
+ "recurrent_gemma",
615
+ (
616
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
617
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
618
+ ),
619
+ ),
620
+ (
621
+ "reformer",
622
+ (
623
+ "ReformerTokenizer" if is_sentencepiece_available() else None,
624
+ "ReformerTokenizerFast" if is_tokenizers_available() else None,
625
+ ),
626
+ ),
627
+ (
628
+ "rembert",
629
+ (
630
+ "RemBertTokenizer" if is_sentencepiece_available() else None,
631
+ "RemBertTokenizerFast" if is_tokenizers_available() else None,
632
+ ),
633
+ ),
634
+ ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
635
+ ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
636
+ (
637
+ "roberta-prelayernorm",
638
+ ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None),
639
+ ),
640
+ ("roc_bert", ("RoCBertTokenizer", None)),
641
+ ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
642
+ ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
643
+ (
644
+ "seamless_m4t",
645
+ (
646
+ "SeamlessM4TTokenizer" if is_sentencepiece_available() else None,
647
+ "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
648
+ ),
649
+ ),
650
+ (
651
+ "seamless_m4t_v2",
652
+ (
653
+ "SeamlessM4TTokenizer" if is_sentencepiece_available() else None,
654
+ "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
655
+ ),
656
+ ),
657
+ (
658
+ "shieldgemma2",
659
+ (
660
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
661
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
662
+ ),
663
+ ),
664
+ ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
665
+ (
666
+ "siglip2",
667
+ (
668
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
669
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
670
+ ),
671
+ ),
672
+ ("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
673
+ ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
674
+ ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
675
+ ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
676
+ ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
677
+ (
678
+ "squeezebert",
679
+ ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
680
+ ),
681
+ ("stablelm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
682
+ ("starcoder2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
683
+ (
684
+ "switch_transformers",
685
+ (
686
+ "T5Tokenizer" if is_sentencepiece_available() else None,
687
+ "T5TokenizerFast" if is_tokenizers_available() else None,
688
+ ),
689
+ ),
690
+ (
691
+ "t5",
692
+ (
693
+ "T5Tokenizer" if is_sentencepiece_available() else None,
694
+ "T5TokenizerFast" if is_tokenizers_available() else None,
695
+ ),
696
+ ),
697
+ (
698
+ "t5gemma",
699
+ (
700
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
701
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
702
+ ),
703
+ ),
704
+ ("tapas", ("TapasTokenizer", None)),
705
+ ("tapex", ("TapexTokenizer", None)),
706
+ ("transfo-xl", ("TransfoXLTokenizer", None)),
707
+ ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
708
+ (
709
+ "udop",
710
+ (
711
+ "UdopTokenizer" if is_sentencepiece_available() else None,
712
+ "UdopTokenizerFast" if is_tokenizers_available() else None,
713
+ ),
714
+ ),
715
+ (
716
+ "umt5",
717
+ (
718
+ "T5Tokenizer" if is_sentencepiece_available() else None,
719
+ "T5TokenizerFast" if is_tokenizers_available() else None,
720
+ ),
721
+ ),
722
+ ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
723
+ ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
724
+ ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
725
+ ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
726
+ ("vits", ("VitsTokenizer", None)),
727
+ (
728
+ "voxtral",
729
+ (
730
+ "MistralCommonTokenizer" if is_mistral_common_available() else None,
731
+ "PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
732
+ ),
733
+ ),
734
+ ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
735
+ ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
736
+ ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
737
+ ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
738
+ ("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
739
+ ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
740
+ (
741
+ "xglm",
742
+ (
743
+ "XGLMTokenizer" if is_sentencepiece_available() else None,
744
+ "XGLMTokenizerFast" if is_tokenizers_available() else None,
745
+ ),
746
+ ),
747
+ ("xlm", ("XLMTokenizer", None)),
748
+ ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
749
+ (
750
+ "xlm-roberta",
751
+ (
752
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
753
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
754
+ ),
755
+ ),
756
+ (
757
+ "xlm-roberta-xl",
758
+ (
759
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
760
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
761
+ ),
762
+ ),
763
+ (
764
+ "xlnet",
765
+ (
766
+ "XLNetTokenizer" if is_sentencepiece_available() else None,
767
+ "XLNetTokenizerFast" if is_tokenizers_available() else None,
768
+ ),
769
+ ),
770
+ ("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
771
+ (
772
+ "xmod",
773
+ (
774
+ "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
775
+ "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
776
+ ),
777
+ ),
778
+ (
779
+ "yoso",
780
+ (
781
+ "AlbertTokenizer" if is_sentencepiece_available() else None,
782
+ "AlbertTokenizerFast" if is_tokenizers_available() else None,
783
+ ),
784
+ ),
785
+ (
786
+ "zamba",
787
+ (
788
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
789
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
790
+ ),
791
+ ),
792
+ (
793
+ "zamba2",
794
+ (
795
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
796
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
797
+ ),
798
+ ),
799
+ ]
800
+ )
801
+
802
+ TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
803
+
804
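+ # Reverse of `CONFIG_MAPPING_NAMES`: maps a config class name back to its model type (e.g. "BertConfig" -> "bert").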
+ CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
805
+
806
+
807
+ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
808
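+ # Resolution order: the `PreTrainedTokenizerFast` shortcut, the static name mapping, tokenizers registered at
+ # runtime, and finally the main `transformers` namespace (which may hold a dummy class that raises an informative
+ # import error when a dependency is missing).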
+ if class_name == "PreTrainedTokenizerFast":
809
+ return PreTrainedTokenizerFast
810
+
811
+ for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
812
+ if class_name in tokenizers:
813
+ module_name = model_type_to_module_name(module_name)
814
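+ # `MistralCommonTokenizer` lives in the top-level `transformers.tokenization_mistral_common` module rather than
+ # under `transformers.models`, so it needs a dedicated import path.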
+ if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonTokenizer":
815
+ module = importlib.import_module(".tokenization_mistral_common", "transformers")
816
+ else:
817
+ module = importlib.import_module(f".{module_name}", "transformers.models")
818
+ try:
819
+ return getattr(module, class_name)
820
+ except AttributeError:
821
+ continue
822
+
823
+ for tokenizers in TOKENIZER_MAPPING._extra_content.values():
824
+ for tokenizer in tokenizers:
825
+ if getattr(tokenizer, "__name__", None) == class_name:
826
+ return tokenizer
827
+
828
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
829
+ # init and we return the proper dummy to get an appropriate error message.
830
+ main_module = importlib.import_module("transformers")
831
+ if hasattr(main_module, class_name):
832
+ return getattr(main_module, class_name)
833
+
834
+ return None
835
+
836
+
837
+ def get_tokenizer_config(
838
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]],
839
+ cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
840
+ force_download: bool = False,
841
+ resume_download: Optional[bool] = None,
842
+ proxies: Optional[dict[str, str]] = None,
843
+ token: Optional[Union[bool, str]] = None,
844
+ revision: Optional[str] = None,
845
+ local_files_only: bool = False,
846
+ subfolder: str = "",
847
+ **kwargs,
848
+ ) -> dict[str, Any]:
849
+ """
850
+ Loads the tokenizer configuration from a pretrained model tokenizer configuration.
851
+
852
+ Args:
853
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
854
+ This can be either:
855
+
856
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
857
+ huggingface.co.
858
+ - a path to a *directory* containing a configuration file saved using the
859
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
860
+
861
+ cache_dir (`str` or `os.PathLike`, *optional*):
862
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
863
+ cache should not be used.
864
+ force_download (`bool`, *optional*, defaults to `False`):
865
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
866
+ exist.
867
+ resume_download:
868
+ Deprecated and ignored. All downloads are now resumed by default when possible.
869
+ Will be removed in v5 of Transformers.
870
+ proxies (`dict[str, str]`, *optional*):
871
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
872
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
873
+ token (`str` or *bool*, *optional*):
874
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
875
+ when running `hf auth login` (stored in `~/.huggingface`).
876
+ revision (`str`, *optional*, defaults to `"main"`):
877
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
878
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
879
+ identifier allowed by git.
880
+ local_files_only (`bool`, *optional*, defaults to `False`):
881
+ If `True`, will only try to load the tokenizer configuration from local files.
882
+ subfolder (`str`, *optional*, defaults to `""`):
883
+ In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
884
+ specify the folder name here.
885
+
886
+ <Tip>
887
+
888
+ Passing `token=True` is required when you want to use a private model.
889
+
890
+ </Tip>
891
+
892
+ Returns:
893
+ `dict`: The configuration of the tokenizer.
894
+
895
+ Examples:
896
+
897
+ ```python
898
+ # Download configuration from huggingface.co and cache.
899
+ tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
900
+ # This model does not have a tokenizer config so the result will be an empty dict.
901
+ tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")
902
+
903
+ # Save a pretrained tokenizer locally and you can reload its config
904
+ from transformers import AutoTokenizer
905
+
906
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
907
+ tokenizer.save_pretrained("tokenizer-test")
908
+ tokenizer_config = get_tokenizer_config("tokenizer-test")
909
+ ```"""
910
+ use_auth_token = kwargs.pop("use_auth_token", None)
911
+ if use_auth_token is not None:
912
+ warnings.warn(
913
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
914
+ FutureWarning,
915
+ )
916
+ if token is not None:
917
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
918
+ token = use_auth_token
919
+
920
+ commit_hash = kwargs.get("_commit_hash")
921
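+ # Fetch failures are deliberately swallowed here (see the `_raise_exceptions_*` flags below): a repo without a
+ # tokenizer config simply yields an empty dict.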
+ resolved_config_file = cached_file(
922
+ pretrained_model_name_or_path,
923
+ TOKENIZER_CONFIG_FILE,
924
+ cache_dir=cache_dir,
925
+ force_download=force_download,
926
+ resume_download=resume_download,
927
+ proxies=proxies,
928
+ token=token,
929
+ revision=revision,
930
+ local_files_only=local_files_only,
931
+ subfolder=subfolder,
932
+ _raise_exceptions_for_gated_repo=False,
933
+ _raise_exceptions_for_missing_entries=False,
934
+ _raise_exceptions_for_connection_errors=False,
935
+ _commit_hash=commit_hash,
936
+ )
937
+ if resolved_config_file is None:
938
+ logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
939
+ return {}
940
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
941
+
942
+ with open(resolved_config_file, encoding="utf-8") as reader:
943
+ result = json.load(reader)
944
+ result["_commit_hash"] = commit_hash
945
+ return result
946
+
947
+
948
+ class AutoTokenizer:
949
+ r"""
950
+ This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
951
+ created with the [`AutoTokenizer.from_pretrained`] class method.
952
+
953
+ This class cannot be instantiated directly using `__init__()` (throws an error).
954
+ """
955
+
956
+ def __init__(self):
957
+ raise OSError(
958
+ "AutoTokenizer is designed to be instantiated "
959
+ "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
960
+ )
961
+
962
+ @classmethod
963
+ @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
964
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
965
+ r"""
966
+ Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
967
+
968
+ The tokenizer class to instantiate is selected based on the `model_type` property of the config object (either
969
+ passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
970
+ falling back to using pattern matching on `pretrained_model_name_or_path`:
971
+
972
+ List options
973
+
974
+ Params:
975
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
976
+ Can be either:
977
+
978
+ - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
979
+ - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
980
+ using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
981
+ - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
982
+ single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
983
+ applicable to all derived classes)
984
+ inputs (additional positional arguments, *optional*):
985
+ Will be passed along to the Tokenizer `__init__()` method.
986
+ config ([`PretrainedConfig`], *optional*):
987
+ The configuration object used to determine the tokenizer class to instantiate.
988
+ cache_dir (`str` or `os.PathLike`, *optional*):
989
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
990
+ standard cache should not be used.
991
+ force_download (`bool`, *optional*, defaults to `False`):
992
+ Whether or not to force the (re-)download of the model weights and configuration files and override the
993
+ cached versions if they exist.
994
+ resume_download:
995
+ Deprecated and ignored. All downloads are now resumed by default when possible.
996
+ Will be removed in v5 of Transformers.
997
+ proxies (`dict[str, str]`, *optional*):
998
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
999
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1000
+ revision (`str`, *optional*, defaults to `"main"`):
1001
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1002
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1003
+ identifier allowed by git.
1004
+ subfolder (`str`, *optional*):
1005
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
1006
+ facebook/rag-token-base), specify it here.
1007
+ use_fast (`bool`, *optional*, defaults to `True`):
1008
+ Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported for
1009
+ a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer
1010
+ is returned instead.
1011
+ tokenizer_type (`str`, *optional*):
1012
+ Tokenizer type to be loaded.
1013
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
1014
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
1015
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
1016
+ execute code present on the Hub on your local machine.
1017
+ kwargs (additional keyword arguments, *optional*):
1018
+ Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
1019
+ `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
1020
+ `additional_special_tokens`. See parameters in the `__init__()` for more details.
1021
+
1022
+ Examples:
1023
+
1024
+ ```python
1025
+ >>> from transformers import AutoTokenizer
1026
+
1027
+ >>> # Download vocabulary from huggingface.co and cache.
1028
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1029
+
1030
+ >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
1031
+ >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
1032
+
1033
+ >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
1034
+ >>> # tokenizer = AutoTokenizer.from_pretrained("./test/saved_model/")
1035
+
1036
+ >>> # Download vocabulary from huggingface.co and define model-specific arguments
1037
+ >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
1038
+ ```"""
1039
+ use_auth_token = kwargs.pop("use_auth_token", None)
1040
+ if use_auth_token is not None:
1041
+ warnings.warn(
1042
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1043
+ FutureWarning,
1044
+ )
1045
+ if kwargs.get("token") is not None:
1046
+ raise ValueError(
1047
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1048
+ )
1049
+ kwargs["token"] = use_auth_token
1050
+
1051
+ config = kwargs.pop("config", None)
1052
+ kwargs["_from_auto"] = True
1053
+
1054
+ use_fast = kwargs.pop("use_fast", True)
1055
+ tokenizer_type = kwargs.pop("tokenizer_type", None)
1056
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
1057
+ gguf_file = kwargs.get("gguf_file")
1058
+
1059
+ # First, let's see whether the tokenizer_type is passed so that we can leverage it
1060
+ if tokenizer_type is not None:
1061
+ tokenizer_class = None
1062
+ tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)
1063
+
1064
+ if tokenizer_class_tuple is None:
1065
+ raise ValueError(
1066
+ f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
1067
+ f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES)}."
1068
+ )
1069
+
1070
+ tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple
1071
+
1072
+ if use_fast:
1073
+ if tokenizer_fast_class_name is not None:
1074
+ tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
1075
+ else:
1076
+ logger.warning(
1077
+ "`use_fast` is set to `True` but the tokenizer class does not have a fast version. "
1078
+ " Falling back to the slow version."
1079
+ )
1080
+ if tokenizer_class is None:
1081
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)
1082
+
1083
+ if tokenizer_class is None:
1084
+ raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")
1085
+
1086
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1087
+
1088
+ # Next, let's try to use the tokenizer_config file to get the tokenizer class.
1089
+ tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
1090
+ if "_commit_hash" in tokenizer_config:
1091
+ kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
1092
+ config_tokenizer_class = tokenizer_config.get("tokenizer_class")
1093
+ tokenizer_auto_map = None
1094
+ if "auto_map" in tokenizer_config:
1095
+ if isinstance(tokenizer_config["auto_map"], (tuple, list)):
1096
+ # Legacy format for dynamic tokenizers
1097
+ tokenizer_auto_map = tokenizer_config["auto_map"]
1098
+ else:
1099
+ tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)
1100
+
1101
+ # If that did not work, let's try to use the config.
1102
+ if config_tokenizer_class is None:
1103
+ if not isinstance(config, PretrainedConfig):
1104
+ if gguf_file:
1105
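+ # GGUF checkpoints embed the model config in the checkpoint file itself, so parse it from there rather
+ # than from a separate config.json.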
+ gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
1106
+ config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
1107
+ config = AutoConfig.for_model(**config_dict)
1108
+ else:
1109
+ config = AutoConfig.from_pretrained(
1110
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
1111
+ )
1112
+ config_tokenizer_class = config.tokenizer_class
1113
+ if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
1114
+ tokenizer_auto_map = config.auto_map["AutoTokenizer"]
1115
+
1116
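+ # A tokenizer that defines an `auto_map` ships its own code on the Hub; `resolve_trust_remote_code` below
+ # decides whether to load that remote code or a class bundled with the library.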
+ has_remote_code = tokenizer_auto_map is not None
1117
+ has_local_code = type(config) in TOKENIZER_MAPPING or (
1118
+ config_tokenizer_class is not None
1119
+ and (
1120
+ tokenizer_class_from_name(config_tokenizer_class) is not None
1121
+ or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None
1122
+ )
1123
+ )
1124
+ if has_remote_code:
1125
+ if use_fast and tokenizer_auto_map[1] is not None:
1126
+ class_ref = tokenizer_auto_map[1]
1127
+ else:
1128
+ class_ref = tokenizer_auto_map[0]
1129
+ if "--" in class_ref:
1130
+ upstream_repo = class_ref.split("--")[0]
1131
+ else:
1132
+ upstream_repo = None
1133
+ trust_remote_code = resolve_trust_remote_code(
1134
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
1135
+ )
1136
+
1137
+ if has_remote_code and trust_remote_code:
1138
+ tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
1139
+ _ = kwargs.pop("code_revision", None)
1140
+ tokenizer_class.register_for_auto_class()
1141
+ return tokenizer_class.from_pretrained(
1142
+ pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs
1143
+ )
1144
+ elif config_tokenizer_class is not None:
1145
+ tokenizer_class = None
1146
+ if use_fast and not config_tokenizer_class.endswith("Fast"):
1147
+ tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
1148
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
1149
+ if tokenizer_class is None:
1150
+ tokenizer_class_candidate = config_tokenizer_class
1151
+ tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
1152
+ if tokenizer_class is None:
1153
+ raise ValueError(
1154
+ f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
1155
+ )
1156
+ return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1157
+
1158
+ # Otherwise we have to be creative.
1159
+ # if model is an encoder decoder, the encoder tokenizer class is used by default
1160
+ if isinstance(config, EncoderDecoderConfig):
1161
+ if type(config.decoder) is not type(config.encoder):
1162
+ logger.warning(
1163
+ f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
1164
+ f"config class: {config.decoder.__class__}. It is not recommended to use the "
1165
+ "`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
1166
+ "specific tokenizer classes."
1167
+ )
1168
+ config = config.encoder
1169
+
1170
+ model_type = config_class_to_model_type(type(config).__name__)
1171
+ if model_type is not None:
1172
+ tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
1173
+
1174
+ if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
1175
+ return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1176
+ else:
1177
+ if tokenizer_class_py is not None:
1178
+ return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
1179
+ else:
1180
+ raise ValueError(
1181
+ "This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
1182
+ "in order to use this tokenizer."
1183
+ )
1184
+
1185
+ raise ValueError(
1186
+ f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
1187
+ f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING)}."
1188
+ )
1189
+
1190
+ @staticmethod
1191
+ def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
1192
+ """
1193
+ Register a new tokenizer in this mapping.
1194
+
1195
+
1196
+ Args:
1197
+ config_class ([`PretrainedConfig`]):
1198
+ The configuration corresponding to the model to register.
1199
+ slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
1200
+ The slow tokenizer to register.
1201
+ fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
1202
+ The fast tokenizer to register.
1203
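+ 
+ Examples:
+ 
+ A minimal sketch of the registration flow; `CustomConfig`, `CustomTokenizer` and `CustomTokenizerFast` are
+ hypothetical user-defined classes, not classes shipped with the library.
+ 
+ ```python
+ >>> from transformers import AutoConfig, AutoTokenizer
+ 
+ >>> AutoConfig.register("custom-model", CustomConfig)
+ >>> AutoTokenizer.register(
+ ...     CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast
+ ... )
+ ```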
+ """
1204
+ if slow_tokenizer_class is None and fast_tokenizer_class is None:
1205
+ raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class")
1206
+ if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
1207
+ raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
1208
+ if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
1209
+ raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")
1210
+
1211
+ if (
1212
+ slow_tokenizer_class is not None
1213
+ and fast_tokenizer_class is not None
1214
+ and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
1215
+ and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
1216
+ ):
1217
+ raise ValueError(
1218
+ "The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
1219
+ "consistent with the slow tokenizer class you passed (fast tokenizer has "
1220
+ f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
1221
+ "so they match!"
1222
+ )
1223
+
1224
+ # Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
1225
+ if config_class in TOKENIZER_MAPPING._extra_content:
1226
+ existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
1227
+ if slow_tokenizer_class is None:
1228
+ slow_tokenizer_class = existing_slow
1229
+ if fast_tokenizer_class is None:
1230
+ fast_tokenizer_class = existing_fast
1231
+
1232
+ TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
1233
+
1234
+
1235
+ __all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py ADDED
@@ -0,0 +1,393 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """AutoVideoProcessor class."""
16
+
17
+ import importlib
18
+ import json
19
+ import os
20
+ import warnings
21
+ from collections import OrderedDict
22
+ from typing import TYPE_CHECKING, Optional, Union
23
+
24
+ # Build the list of all video processors
25
+ from ...configuration_utils import PretrainedConfig
26
+ from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
27
+ from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
28
+ from ...utils.import_utils import requires
29
+ from ...video_processing_utils import BaseVideoProcessor
30
+ from .auto_factory import _LazyAutoMapping
31
+ from .configuration_auto import (
32
+ CONFIG_MAPPING_NAMES,
33
+ AutoConfig,
34
+ model_type_to_module_name,
35
+ replace_list_option_in_docstrings,
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ if TYPE_CHECKING:
43
+ # This significantly improves completion suggestion performance when
44
+ # the transformers package is used with Microsoft's Pylance language server.
45
+ VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
46
+ else:
47
+ VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
48
+ [
49
+ ("glm4v", "Glm4vVideoProcessor"),
50
+ ("instructblip", "InstructBlipVideoVideoProcessor"),
51
+ ("instructblipvideo", "InstructBlipVideoVideoProcessor"),
52
+ ("internvl", "InternVLVideoProcessor"),
53
+ ("llava_next_video", "LlavaNextVideoVideoProcessor"),
54
+ ("llava_onevision", "LlavaOnevisionVideoProcessor"),
55
+ ("perception_lm", "PerceptionLMVideoProcessor"),
56
+ ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
57
+ ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
58
+ ("qwen2_vl", "Qwen2VLVideoProcessor"),
59
+ ("qwen3_omni_moe", "Qwen2VLVideoProcessor"),
60
+ ("qwen3_vl", "Qwen3VLVideoProcessor"),
61
+ ("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
62
+ ("sam2_video", "Sam2VideoVideoProcessor"),
63
+ ("smolvlm", "SmolVLMVideoProcessor"),
64
+ ("video_llava", "VideoLlavaVideoProcessor"),
65
+ ("vjepa2", "VJEPA2VideoProcessor"),
66
+ ]
67
+ )
68
+
69
+ for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
70
+ fast_video_processor_class = video_processors
71
+
72
+ # If torchvision is not available, set the video processor class to None
73
+ if not is_torchvision_available():
74
+ fast_video_processor_class = None
75
+
76
+ VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class
77
+
78
+ VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)
79
+
80
+
81
+ def video_processor_class_from_name(class_name: str):
82
+ for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
83
+ if class_name in extractors:
84
+ module_name = model_type_to_module_name(module_name)
85
+
86
+ module = importlib.import_module(f".{module_name}", "transformers.models")
87
+ try:
88
+ return getattr(module, class_name)
89
+ except AttributeError:
90
+ continue
91
+
92
+ for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values():
93
+ if getattr(extractor, "__name__", None) == class_name:
94
+ return extractor
95
+
96
+ # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
97
+ # init and we return the proper dummy to get an appropriate error message.
98
+ main_module = importlib.import_module("transformers")
99
+ if hasattr(main_module, class_name):
100
+ return getattr(main_module, class_name)
101
+
102
+ return None
103
+
104
+
105
+ def get_video_processor_config(
106
+ pretrained_model_name_or_path: Union[str, os.PathLike],
107
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
108
+ force_download: bool = False,
109
+ resume_download: Optional[bool] = None,
110
+ proxies: Optional[dict[str, str]] = None,
111
+ token: Optional[Union[bool, str]] = None,
112
+ revision: Optional[str] = None,
113
+ local_files_only: bool = False,
114
+ **kwargs,
115
+ ):
116
+ """
117
+ Loads the video processor configuration from a pretrained model video processor configuration.
118
+
119
+ Args:
120
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
121
+ This can be either:
122
+
123
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
124
+ huggingface.co.
125
+ - a path to a *directory* containing a configuration file saved using the
126
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
127
+
128
+ cache_dir (`str` or `os.PathLike`, *optional*):
129
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
130
+ cache should not be used.
131
+ force_download (`bool`, *optional*, defaults to `False`):
132
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
133
+ exist.
134
+ resume_download:
135
+ Deprecated and ignored. All downloads are now resumed by default when possible.
136
+ Will be removed in v5 of Transformers.
137
+ proxies (`dict[str, str]`, *optional*):
138
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
139
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
140
+ token (`str` or *bool*, *optional*):
141
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
142
+ when running `hf auth login` (stored in `~/.huggingface`).
143
+ revision (`str`, *optional*, defaults to `"main"`):
144
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
145
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
146
+ identifier allowed by git.
147
+ local_files_only (`bool`, *optional*, defaults to `False`):
148
+ If `True`, will only try to load the video processor configuration from local files.
149
+
150
+ <Tip>
151
+
152
+ Passing `token=True` is required when you want to use a private model.
153
+
154
+ </Tip>
155
+
156
+ Returns:
157
+ `Dict`: The configuration of the video processor.
158
+
159
+ Examples:
160
+
161
+ ```python
162
+ # Download configuration from huggingface.co and cache.
163
+ video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
164
+ # This model does not have a video processor config so the result will be an empty dict.
165
+ video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")
166
+
167
+ # Save a pretrained video processor locally and you can reload its config
168
+ from transformers import AutoVideoProcessor
169
+
170
+ video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
171
+ video_processor.save_pretrained("video-processor-test")
172
+ video_processor_config = get_video_processor_config("video-processor-test")
173
+ ```"""
174
+ use_auth_token = kwargs.pop("use_auth_token", None)
175
+ if use_auth_token is not None:
176
+ warnings.warn(
177
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
178
+ FutureWarning,
179
+ )
180
+ if token is not None:
181
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
182
+ token = use_auth_token
183
+
184
+ resolved_config_file = cached_file(
185
+ pretrained_model_name_or_path,
186
+ VIDEO_PROCESSOR_NAME,
187
+ cache_dir=cache_dir,
188
+ force_download=force_download,
189
+ resume_download=resume_download,
190
+ proxies=proxies,
191
+ token=token,
192
+ revision=revision,
193
+ local_files_only=local_files_only,
194
+ )
195
+ if resolved_config_file is None:
196
+ logger.info(
197
+ "Could not locate the video processor configuration file, will try to use the model config instead."
198
+ )
199
+ return {}
200
+
201
+ with open(resolved_config_file, encoding="utf-8") as reader:
202
+ return json.load(reader)
203
+
204
+
205
+ @requires(backends=("vision", "torchvision"))
206
+ class AutoVideoProcessor:
207
+ r"""
208
+ This is a generic video processor class that will be instantiated as one of the video processor classes of the
209
+ library when created with the [`AutoVideoProcessor.from_pretrained`] class method.
210
+
211
+ This class cannot be instantiated directly using `__init__()` (throws an error).
212
+ """
213
+
214
+ def __init__(self):
215
+ raise OSError(
216
+ "AutoVideoProcessor is designed to be instantiated "
217
+ "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
218
+ )
219
+
220
+ @classmethod
221
+ @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
222
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
223
+ r"""
224
+ Instantiate one of the video processor classes of the library from a pretrained model vocabulary.
225
+
226
+ The video processor class to instantiate is selected based on the `model_type` property of the config object
227
+ (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
228
+ missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
229
+
230
+ List options
231
+
232
+ Params:
233
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
234
+ This can be either:
235
+
236
+ - a string, the *model id* of a pretrained video_processor hosted inside a model repo on
237
+ huggingface.co.
238
+ - a path to a *directory* containing a video processor file saved using the
239
+ [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
240
+ `./my_model_directory/`.
241
+ - a path or url to a saved video processor JSON *file*, e.g.,
242
+ `./my_model_directory/preprocessor_config.json`.
243
+ cache_dir (`str` or `os.PathLike`, *optional*):
244
+ Path to a directory in which a downloaded pretrained model video processor should be cached if the
245
+ standard cache should not be used.
246
+ force_download (`bool`, *optional*, defaults to `False`):
247
+ Whether or not to force to (re-)download the video processor files and override the cached versions if
248
+ they exist.
249
+ resume_download:
250
+ Deprecated and ignored. All downloads are now resumed by default when possible.
251
+ Will be removed in v5 of Transformers.
252
+ proxies (`dict[str, str]`, *optional*):
253
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
254
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
255
+ token (`str` or *bool*, *optional*):
256
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
257
+ when running `hf auth login` (stored in `~/.huggingface`).
258
+ revision (`str`, *optional*, defaults to `"main"`):
259
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
260
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
261
+ identifier allowed by git.
262
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
263
+ If `False`, then this function returns just the final video processor object. If `True`, then this
264
+ functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
265
+ consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
266
+ `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
267
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
268
+ Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
269
+ should only be set to `True` for repositories you trust and in which you have read the code, as it will
270
+ execute code present on the Hub on your local machine.
271
+ kwargs (`dict[str, Any]`, *optional*):
272
+ The values in kwargs of any keys which are video processor attributes will be used to override the
273
+ loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
274
+ controlled by the `return_unused_kwargs` keyword parameter.
275
+
276
+ <Tip>
277
+
278
+ Passing `token=True` is required when you want to use a private model.
279
+
280
+ </Tip>
281
+
282
+ Examples:
283
+
284
+ ```python
285
+ >>> from transformers import AutoVideoProcessor
286
+
287
+ >>> # Download video processor from huggingface.co and cache.
288
+ >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
289
+
290
+ >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
291
+ >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
292
+ ```"""
293
+ use_auth_token = kwargs.pop("use_auth_token", None)
294
+ if use_auth_token is not None:
295
+ warnings.warn(
296
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
297
+ FutureWarning,
298
+ )
299
+ if kwargs.get("token") is not None:
300
+ raise ValueError(
301
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
302
+ )
303
+ kwargs["token"] = use_auth_token
304
+
305
+ config = kwargs.pop("config", None)
306
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
307
+ kwargs["_from_auto"] = True
308
+
309
+ config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
310
+ video_processor_class = config_dict.get("video_processor_type", None)
311
+ video_processor_auto_map = None
312
+ if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
313
+ video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]
314
+
315
+ # If we still don't have the video processor class, check if we're loading from a previous image processor config
316
+ # and if so, infer the video processor class from there.
317
+ if video_processor_class is None and video_processor_auto_map is None:
318
+ image_processor_class = config_dict.pop("image_processor_type", None)
319
+ if image_processor_class is not None:
320
+ video_processor_class_inferred = image_processor_class.replace("ImageProcessor", "VideoProcessor")
321
+
322
+ # Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
323
+ # We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
324
+ if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
325
+ video_processor_class = video_processor_class_inferred
326
+ if "AutoImageProcessor" in config_dict.get("auto_map", {}):
327
+ image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
328
+ video_processor_auto_map = image_processor_auto_map.replace("ImageProcessor", "VideoProcessor")
329
+
330
+ # If we don't find the video processor class in the video processor config, let's try the model config.
331
+ if video_processor_class is None and video_processor_auto_map is None:
332
+ if not isinstance(config, PretrainedConfig):
333
+ config = AutoConfig.from_pretrained(
334
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
335
+ )
336
+ # It could be in `config.video_processor_type``
337
+ video_processor_class = getattr(config, "video_processor_type", None)
338
+ if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
339
+ video_processor_auto_map = config.auto_map["AutoVideoProcessor"]
340
+
341
+ if video_processor_class is not None:
342
+ video_processor_class = video_processor_class_from_name(video_processor_class)
343
+
344
+ has_remote_code = video_processor_auto_map is not None
345
+ has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
346
+ trust_remote_code = resolve_trust_remote_code(
347
+ trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
348
+ )
349
+
350
+ if has_remote_code and trust_remote_code:
351
+ class_ref = video_processor_auto_map
352
+ video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
353
+ _ = kwargs.pop("code_revision", None)
354
+ video_processor_class.register_for_auto_class()
355
+ return video_processor_class.from_dict(config_dict, **kwargs)
356
+ elif video_processor_class is not None:
357
+ return video_processor_class.from_dict(config_dict, **kwargs)
358
+ # Last try: we use the VIDEO_PROCESSOR_MAPPING.
359
+ elif type(config) in VIDEO_PROCESSOR_MAPPING:
360
+ video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]
361
+
362
+ if video_processor_class is not None:
363
+ return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
364
+ else:
365
+ raise ValueError(
366
+ "This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
367
+ )
368
+
369
+ raise ValueError(
370
+ f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
371
+ f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
372
+ f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES)}"
373
+ )
374
+
375
+ @staticmethod
376
+ def register(
377
+ config_class,
378
+ video_processor_class,
379
+ exist_ok=False,
380
+ ):
381
+ """
382
+ Register a new video processor for this class.
383
+
384
+ Args:
385
+ config_class ([`PretrainedConfig`]):
386
+ The configuration corresponding to the model to register.
387
+ video_processor_class ([`BaseVideoProcessor`]):
388
+ The video processor to register.
389
+ """
390
+ VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)
391
+
392
+
393
+ __all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
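The `register` hook at the end of this file is the extension point for pairing a custom config class with a custom video processor. A minimal sketch of the wiring, assuming hypothetical `MyVideoConfig` and `MyVideoProcessor` classes (they are not part of this commit, and `BaseVideoProcessor` assumes the vision/torchvision backends):

```python
# Hypothetical example: MyVideoConfig / MyVideoProcessor are illustrative stand-ins.
from transformers import AutoConfig, AutoVideoProcessor, PretrainedConfig
from transformers.video_processing_utils import BaseVideoProcessor


class MyVideoConfig(PretrainedConfig):
    model_type = "my-video-model"


class MyVideoProcessor(BaseVideoProcessor):
    pass


# Register the config type first, then map it to the processor so that
# AutoVideoProcessor.from_pretrained can resolve it via VIDEO_PROCESSOR_MAPPING.
AutoConfig.register("my-video-model", MyVideoConfig)
AutoVideoProcessor.register(MyVideoConfig, MyVideoProcessor)
```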
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py ADDED
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_bark import *
+    from .modeling_bark import *
+    from .processing_bark import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
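The `_LazyModule` indirection keeps importing the package cheap: submodule contents are only resolved on first attribute access. A minimal sketch of the same idea using PEP 562 module-level `__getattr__` (illustrative only; `_LazyModule` itself works differently):

```python
# Illustrative lazy-import pattern (PEP 562), not transformers' actual _LazyModule.
import importlib

_LAZY_ATTRS = {"BarkConfig": ".configuration_bark"}  # attribute name -> submodule


def __getattr__(name):
    # Import the submodule on first attribute access, then cache the result
    # in the package globals so subsequent lookups are ordinary attribute hits.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        value = getattr(module, name)
        globals()[name] = value
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```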
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py ADDED
@@ -0,0 +1,303 @@
+# coding=utf-8
+# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BARK model configuration"""
+
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import add_start_docstrings, logging
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+BARK_SUBMODELCONFIG_START_DOCSTRING = """
+    This is the configuration class to store the configuration of a [`{model}`]. It is used to instantiate the model
+    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Bark [suno/bark](https://huggingface.co/suno/bark)
+    architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        block_size (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        input_vocab_size (`int`, *optional*, defaults to 10_048):
+            Vocabulary size of a Bark sub-model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`{model}`]. Defaults to 10_048 but should be chosen carefully with regard
+            to the selected sub-model.
+        output_vocab_size (`int`, *optional*, defaults to 10_048):
+            Output vocabulary size of a Bark sub-model. Defines the number of different tokens that can be represented
+            by the `output_ids` when passing forward a [`{model}`]. Defaults to 10_048 but should be chosen carefully
+            with regard to the selected sub-model.
+        num_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the given sub-model.
+        num_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer architecture.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the hidden representations in the architecture.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to use bias in the linear layers and layer norm layers.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+"""
+
+
+class BarkSubModelConfig(PretrainedConfig):
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    attribute_map = {
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+        "vocab_size": "input_vocab_size",
+        "window_size": "block_size",
+    }
+
+    def __init__(
+        self,
+        block_size=1024,
+        input_vocab_size=10_048,
+        output_vocab_size=10_048,
+        num_layers=12,
+        num_heads=12,
+        hidden_size=768,
+        dropout=0.0,
+        bias=True,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
+        initializer_range=0.02,
+        use_cache=True,
+        **kwargs,
+    ):
+        self.block_size = block_size
+        self.input_vocab_size = input_vocab_size
+        self.output_vocab_size = output_vocab_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.bias = bias
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+
+        super().__init__(**kwargs)
+
+
+@add_start_docstrings(
+    BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkSemanticConfig", model="BarkSemanticModel"),
+    """
+    Example:
+
+    ```python
+    >>> from transformers import BarkSemanticConfig, BarkSemanticModel
+
+    >>> # Initializing a Bark sub-module style configuration
+    >>> configuration = BarkSemanticConfig()
+
+    >>> # Initializing a model (with random weights) from the suno/bark style configuration
+    >>> model = BarkSemanticModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```""",
+)
+class BarkSemanticConfig(BarkSubModelConfig):
+    model_type = "semantic"
+    base_config_key = "semantic_config"
+
+
+@add_start_docstrings(
+    BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkCoarseConfig", model="BarkCoarseModel"),
+    """
+    Example:
+
+    ```python
+    >>> from transformers import BarkCoarseConfig, BarkCoarseModel
+
+    >>> # Initializing a Bark sub-module style configuration
+    >>> configuration = BarkCoarseConfig()
+
+    >>> # Initializing a model (with random weights) from the suno/bark style configuration
+    >>> model = BarkCoarseModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```""",
+)
+class BarkCoarseConfig(BarkSubModelConfig):
+    model_type = "coarse_acoustics"
+    base_config_key = "coarse_acoustics_config"
+
+
+@add_start_docstrings(
+    BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkFineConfig", model="BarkFineModel"),
+    """
+    n_codes_total (`int`, *optional*, defaults to 8):
+        The total number of audio codebooks predicted. Used in the fine acoustics sub-model.
+    n_codes_given (`int`, *optional*, defaults to 1):
+        The number of audio codebooks predicted in the coarse acoustics sub-model. Used in the acoustics
+        sub-models.
+    Example:
+
+    ```python
+    >>> from transformers import BarkFineConfig, BarkFineModel
+
+    >>> # Initializing a Bark sub-module style configuration
+    >>> configuration = BarkFineConfig()
+
+    >>> # Initializing a model (with random weights) from the suno/bark style configuration
+    >>> model = BarkFineModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```""",
+)
+class BarkFineConfig(BarkSubModelConfig):
+    model_type = "fine_acoustics"
+    base_config_key = "fine_acoustics_config"
+
+    def __init__(self, tie_word_embeddings=True, n_codes_total=8, n_codes_given=1, **kwargs):
+        self.n_codes_total = n_codes_total
+        self.n_codes_given = n_codes_given
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class BarkConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`BarkModel`]. It is used to instantiate a Bark
+    model according to the specified sub-model configurations, defining the model architecture.
+
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Bark
+    [suno/bark](https://huggingface.co/suno/bark) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        semantic_config ([`BarkSemanticConfig`], *optional*):
+            Configuration of the underlying semantic sub-model.
+        coarse_acoustics_config ([`BarkCoarseConfig`], *optional*):
+            Configuration of the underlying coarse acoustics sub-model.
+        fine_acoustics_config ([`BarkFineConfig`], *optional*):
+            Configuration of the underlying fine acoustics sub-model.
+        codec_config ([`AutoConfig`], *optional*):
+            Configuration of the underlying codec sub-model.
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     BarkSemanticConfig,
+    ...     BarkCoarseConfig,
+    ...     BarkFineConfig,
+    ...     BarkModel,
+    ...     BarkConfig,
+    ...     AutoConfig,
+    ... )
+
+    >>> # Initializing Bark sub-modules configurations.
+    >>> semantic_config = BarkSemanticConfig()
+    >>> coarse_acoustics_config = BarkCoarseConfig()
+    >>> fine_acoustics_config = BarkFineConfig()
+    >>> codec_config = AutoConfig.from_pretrained("facebook/encodec_24khz")
+
+
+    >>> # Initializing a Bark module style configuration
+    >>> configuration = BarkConfig.from_sub_model_configs(
+    ...     semantic_config, coarse_acoustics_config, fine_acoustics_config, codec_config
+    ... )
+
+    >>> # Initializing a model (with random weights)
+    >>> model = BarkModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "bark"
+    sub_configs = {
+        "semantic_config": BarkSemanticConfig,
+        "coarse_acoustics_config": BarkCoarseConfig,
+        "fine_acoustics_config": BarkFineConfig,
+        "codec_config": AutoConfig,
+    }
+
+    def __init__(
+        self,
+        semantic_config: Optional[dict] = None,
+        coarse_acoustics_config: Optional[dict] = None,
+        fine_acoustics_config: Optional[dict] = None,
+        codec_config: Optional[dict] = None,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        if semantic_config is None:
+            semantic_config = {}
+            logger.info("semantic_config is None. Initializing the semantic model with default values.")
+
+        if coarse_acoustics_config is None:
+            coarse_acoustics_config = {}
+            logger.info("coarse_acoustics_config is None. Initializing the coarse model with default values.")
+
+        if fine_acoustics_config is None:
+            fine_acoustics_config = {}
+            logger.info("fine_acoustics_config is None. Initializing the fine model with default values.")
+
+        if codec_config is None:
+            codec_config = {}
+            logger.info("codec_config is None. Initializing the codec model with default values.")
+
+        self.semantic_config = BarkSemanticConfig(**semantic_config)
+        self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config)
+        self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config)
+        codec_model_type = codec_config.get("model_type", "encodec")
+        self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config)
+
+        self.initializer_range = initializer_range
+
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_sub_model_configs(
+        cls,
+        semantic_config: BarkSemanticConfig,
+        coarse_acoustics_config: BarkCoarseConfig,
+        fine_acoustics_config: BarkFineConfig,
+        codec_config: PretrainedConfig,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`BarkConfig`] (or a derived class) from bark sub-model configurations.
+
+        Returns:
+            [`BarkConfig`]: An instance of a configuration object
+        """
+        return cls(
+            semantic_config=semantic_config.to_dict(),
+            coarse_acoustics_config=coarse_acoustics_config.to_dict(),
+            fine_acoustics_config=fine_acoustics_config.to_dict(),
+            codec_config=codec_config.to_dict(),
+            **kwargs,
+        )
+
+
+__all__ = ["BarkCoarseConfig", "BarkConfig", "BarkFineConfig", "BarkSemanticConfig"]
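`BarkConfig.__init__` accepts the sub-configurations as plain dicts and re-instantiates them, which is exactly what `from_sub_model_configs` relies on. A short round-trip sketch (the `num_layers`/`num_heads` overrides are illustrative values):

```python
from transformers import BarkConfig, BarkSemanticConfig

# Nested dicts are turned back into sub-config objects by BarkConfig.__init__.
config = BarkConfig(semantic_config={"num_layers": 6, "num_heads": 6})
assert isinstance(config.semantic_config, BarkSemanticConfig)

# to_dict() serializes the nested configs, so the whole config round-trips.
rebuilt = BarkConfig.from_dict(config.to_dict())
assert rebuilt.semantic_config.num_layers == 6
```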
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py ADDED
@@ -0,0 +1,330 @@
+# coding=utf-8
+# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BARK model generation configuration"""
+
+import copy
+from typing import Optional
+
+from ...generation.configuration_utils import GenerationConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BarkSemanticGenerationConfig(GenerationConfig):
+    model_type = "semantic"
+
+    def __init__(
+        self,
+        eos_token_id=10_000,
+        renormalize_logits=True,
+        max_new_tokens=768,
+        output_scores=False,
+        return_dict_in_generate=False,
+        output_hidden_states=False,
+        output_attentions=False,
+        temperature=1.0,
+        do_sample=False,
+        text_encoding_offset=10_048,
+        text_pad_token=129_595,
+        semantic_infer_token=129_599,
+        semantic_vocab_size=10_000,
+        max_input_semantic_length=256,
+        semantic_rate_hz=49.9,
+        min_eos_p=None,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkSemanticModel`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read
+        the documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            eos_token_id (`int`, *optional*, defaults to 10_000):
+                The id of the *end-of-sequence* token.
+            renormalize_logits (`bool`, *optional*, defaults to `True`):
+                Whether to renormalize the logits after applying all the logits processors (including the custom
+                ones). It's highly recommended to set this flag to `True` as the search algorithms assume the score
+                logits are normalized, but some logits processors break the normalization.
+            max_new_tokens (`int`, *optional*, defaults to 768):
+                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
+            output_scores (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            temperature (`float`, *optional*, defaults to 1.0):
+                The value used to modulate the next token probabilities.
+            do_sample (`bool`, *optional*, defaults to `False`):
+                Whether or not to use sampling; use greedy decoding otherwise.
+            text_encoding_offset (`int`, *optional*, defaults to 10_048):
+                Text encoding offset.
+            text_pad_token (`int`, *optional*, defaults to 129_595):
+                Text pad token.
+            semantic_infer_token (`int`, *optional*, defaults to 129_599):
+                Semantic infer token.
+            semantic_vocab_size (`int`, *optional*, defaults to 10_000):
+                Semantic vocab size.
+            max_input_semantic_length (`int`, *optional*, defaults to 256):
+                Max length of the semantic input vector.
+            semantic_rate_hz (`float`, *optional*, defaults to 49.9):
+                Semantic rate in Hertz.
+            min_eos_p (`float`, *optional*):
+                Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
+                strategy to mitigate potential unwanted generations at the end of a prompt. The original
+                implementation suggests a default value of 0.2.
+        """
+        super().__init__(
+            temperature=temperature,
+            do_sample=do_sample,
+            eos_token_id=eos_token_id,
+            renormalize_logits=renormalize_logits,
+            max_new_tokens=max_new_tokens,
+            output_scores=output_scores,
+            return_dict_in_generate=return_dict_in_generate,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        self.text_encoding_offset = text_encoding_offset
+        self.text_pad_token = text_pad_token
+        self.semantic_pad_token = eos_token_id
+        self.semantic_infer_token = semantic_infer_token
+        self.semantic_vocab_size = semantic_vocab_size
+        self.max_input_semantic_length = max_input_semantic_length
+        self.semantic_rate_hz = semantic_rate_hz
+        self.min_eos_p = min_eos_p
+
+
+class BarkCoarseGenerationConfig(GenerationConfig):
+    model_type = "coarse_acoustics"
+
+    def __init__(
+        self,
+        renormalize_logits=True,
+        output_scores=False,
+        return_dict_in_generate=False,
+        output_hidden_states=False,
+        output_attentions=False,
+        temperature=1.0,
+        do_sample=False,
+        coarse_semantic_pad_token=12_048,
+        coarse_rate_hz=75,
+        n_coarse_codebooks=2,
+        coarse_infer_token=12_050,
+        max_coarse_input_length=256,
+        max_coarse_history: int = 630,
+        sliding_window_len: int = 60,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkCoarseModel`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read
+        the documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            renormalize_logits (`bool`, *optional*, defaults to `True`):
+                Whether to renormalize the logits after applying all the logits processors (including the custom
+                ones). It's highly recommended to set this flag to `True` as the search algorithms assume the score
+                logits are normalized, but some logits processors break the normalization.
+            output_scores (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            temperature (`float`, *optional*, defaults to 1.0):
+                The value used to modulate the next token probabilities.
+            do_sample (`bool`, *optional*, defaults to `False`):
+                Whether or not to use sampling; use greedy decoding otherwise.
+            coarse_semantic_pad_token (`int`, *optional*, defaults to 12_048):
+                Coarse semantic pad token.
+            coarse_rate_hz (`int`, *optional*, defaults to 75):
+                Coarse rate in Hertz.
+            n_coarse_codebooks (`int`, *optional*, defaults to 2):
+                Number of coarse codebooks.
+            coarse_infer_token (`int`, *optional*, defaults to 12_050):
+                Coarse infer token.
+            max_coarse_input_length (`int`, *optional*, defaults to 256):
+                Max length of the input coarse vector.
+            max_coarse_history (`int`, *optional*, defaults to 630):
+                Max length of the output of the coarse acoustics model used in the fine generation step.
+            sliding_window_len (`int`, *optional*, defaults to 60):
+                Length of the sliding window the coarse generation step uses to generate raw audio.
+        """
+        super().__init__(
+            temperature=temperature,
+            do_sample=do_sample,
+            renormalize_logits=renormalize_logits,
+            output_scores=output_scores,
+            return_dict_in_generate=return_dict_in_generate,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        self.coarse_semantic_pad_token = coarse_semantic_pad_token
+        self.coarse_rate_hz = coarse_rate_hz
+        self.n_coarse_codebooks = n_coarse_codebooks
+        self.coarse_infer_token = coarse_infer_token
+        self.max_coarse_input_length = max_coarse_input_length
+        self.max_coarse_history = max_coarse_history
+        self.sliding_window_len = sliding_window_len
+
+
+class BarkFineGenerationConfig(GenerationConfig):
+    model_type = "fine_acoustics"
+
+    def __init__(
+        self,
+        temperature=1.0,
+        max_fine_history_length=512,
+        max_fine_input_length=1024,
+        n_fine_codebooks=8,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkFineModel`].
+
+        [`BarkFineModel`] is an autoencoder model, so it should not usually be used for generation. However, under
+        the hood, it uses `temperature` when used by [`BarkModel`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read
+        the documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            temperature (`float`, *optional*, defaults to 1.0):
+                The value used to modulate the next token probabilities.
+            max_fine_history_length (`int`, *optional*, defaults to 512):
+                Max length of the fine history vector.
+            max_fine_input_length (`int`, *optional*, defaults to 1024):
+                Max length of the fine input vector.
+            n_fine_codebooks (`int`, *optional*, defaults to 8):
+                Number of codebooks used.
+        """
+        super().__init__(temperature=temperature)
+
+        self.max_fine_history_length = max_fine_history_length
+        self.max_fine_input_length = max_fine_input_length
+        self.n_fine_codebooks = n_fine_codebooks
+
+    def validate(self, **kwargs):
+        """
+        Overrides GenerationConfig.validate because BarkFineGenerationConfig doesn't use any parameters other than
+        `temperature`.
+        """
+        pass
+
+
+class BarkGenerationConfig(GenerationConfig):
+    model_type = "bark"
+
+    # TODO (joao): nested from_dict
+
+    def __init__(
+        self,
+        semantic_config: Optional[dict] = None,
+        coarse_acoustics_config: Optional[dict] = None,
+        fine_acoustics_config: Optional[dict] = None,
+        sample_rate=24_000,
+        codebook_size=1024,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkModel`].
+
+        The [`BarkModel`] does not have a `generate` method, but uses this class to generate speech with a nested
+        [`BarkGenerationConfig`] which uses [`BarkSemanticGenerationConfig`], [`BarkCoarseGenerationConfig`] and
+        [`BarkFineGenerationConfig`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read
+        the documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            semantic_config (`Dict`, *optional*):
+                Semantic generation configuration.
+            coarse_acoustics_config (`Dict`, *optional*):
+                Coarse generation configuration.
+            fine_acoustics_config (`Dict`, *optional*):
+                Fine generation configuration.
+            sample_rate (`int`, *optional*, defaults to 24_000):
+                Sample rate.
+            codebook_size (`int`, *optional*, defaults to 1024):
+                Vector length for each codebook.
+        """
+        if semantic_config is None:
+            semantic_config = {}
+            logger.info("semantic_config is None. Initializing the semantic model with default values.")
+
+        if coarse_acoustics_config is None:
+            coarse_acoustics_config = {}
+            logger.info("coarse_acoustics_config is None. Initializing the coarse model with default values.")
+
+        if fine_acoustics_config is None:
+            fine_acoustics_config = {}
+            logger.info("fine_acoustics_config is None. Initializing the fine model with default values.")
+
+        self.semantic_config = BarkSemanticGenerationConfig(**semantic_config)
+        self.coarse_acoustics_config = BarkCoarseGenerationConfig(**coarse_acoustics_config)
+        self.fine_acoustics_config = BarkFineGenerationConfig(**fine_acoustics_config)
+
+        self.sample_rate = sample_rate
+        self.codebook_size = codebook_size
+
+    @classmethod
+    def from_sub_model_configs(
+        cls,
+        semantic_config: BarkSemanticGenerationConfig,
+        coarse_acoustics_config: BarkCoarseGenerationConfig,
+        fine_acoustics_config: BarkFineGenerationConfig,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`BarkGenerationConfig`] (or a derived class) from bark sub-model generation configurations.
+
+        Returns:
+            [`BarkGenerationConfig`]: An instance of a configuration object
+        """
+        return cls(
+            semantic_config=semantic_config.to_dict(),
+            coarse_acoustics_config=coarse_acoustics_config.to_dict(),
+            fine_acoustics_config=fine_acoustics_config.to_dict(),
+            **kwargs,
+        )
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+
+        output["semantic_config"] = self.semantic_config.to_dict()
+        output["coarse_acoustics_config"] = self.coarse_acoustics_config.to_dict()
+        output["fine_acoustics_config"] = self.fine_acoustics_config.to_dict()
+
+        output["model_type"] = self.__class__.model_type
+        return output
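The nested generation config mirrors `BarkConfig`: sub-configs go in as dicts via `from_sub_model_configs` and come back out through the overridden `to_dict`. A small sketch, importing from the defining module (the assertion values are illustrative):

```python
from transformers.models.bark.generation_configuration_bark import (
    BarkCoarseGenerationConfig,
    BarkFineGenerationConfig,
    BarkGenerationConfig,
    BarkSemanticGenerationConfig,
)

semantic = BarkSemanticGenerationConfig(min_eos_p=0.2)
gen_config = BarkGenerationConfig.from_sub_model_configs(
    semantic, BarkCoarseGenerationConfig(), BarkFineGenerationConfig()
)

# The to_dict override serializes each sub-config, so nested values survive a round trip.
assert gen_config.to_dict()["semantic_config"]["min_eos_p"] == 0.2
```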
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py ADDED
@@ -0,0 +1,1628 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch BARK model."""
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Optional, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import functional as F
25
+
26
+ from ...cache_utils import Cache, DynamicCache
27
+ from ...generation import GenerationMixin
28
+ from ...generation.logits_process import (
29
+ AlternatingCodebooksLogitsProcessor,
30
+ BarkEosPrioritizerLogitsProcessor,
31
+ SuppressTokensLogitsProcessor,
32
+ )
33
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
34
+ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
35
+ from ...modeling_layers import GradientCheckpointingLayer
36
+ from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
37
+ from ...modeling_utils import PreTrainedModel, get_parameter_device
38
+ from ...utils import (
39
+ auto_docstring,
40
+ is_accelerate_available,
41
+ is_torch_accelerator_available,
42
+ logging,
43
+ )
44
+ from ..auto import AutoModel
45
+ from .configuration_bark import (
46
+ BarkCoarseConfig,
47
+ BarkConfig,
48
+ BarkFineConfig,
49
+ BarkSemanticConfig,
50
+ BarkSubModelConfig,
51
+ )
52
+ from .generation_configuration_bark import (
53
+ BarkCoarseGenerationConfig,
54
+ BarkFineGenerationConfig,
55
+ BarkSemanticGenerationConfig,
56
+ )
57
+
58
+
59
+ if is_flash_attn_available():
60
+ from ...modeling_flash_attention_utils import _flash_attention_forward
61
+
62
+
63
+ logger = logging.get_logger(__name__)
64
+
65
+
66
+ class BarkSelfAttention(nn.Module):
67
+ # adapted from GPTNeoSelfAttention and Bark code
68
+ # BarkSelfAttention can have two attention type, i.e full attention or causal attention
69
+
70
+ def __init__(self, config, is_causal=False, layer_idx=None):
71
+ super().__init__()
72
+
73
+ # regularization
74
+ self.dropout = config.dropout
75
+ self.attn_dropout = nn.Dropout(config.dropout)
76
+ self.resid_dropout = nn.Dropout(config.dropout)
77
+
78
+ self.embed_dim = config.hidden_size
79
+ self.num_heads = config.num_heads
80
+ self.head_dim = self.embed_dim // self.num_heads
81
+
82
+ if config.hidden_size % config.num_heads != 0:
83
+ raise ValueError(
84
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
85
+ f" {self.num_heads})."
86
+ )
87
+
88
+ # key, query, value projections for all heads, but in a batch
89
+ self.att_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
90
+ # output projection
91
+ self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias)
92
+
93
+ self.is_causal = is_causal
94
+ self.layer_idx = layer_idx
95
+ if is_causal:
96
+ block_size = config.block_size
97
+ bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
98
+ self.register_buffer("bias", bias)
99
+
100
+ # Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads
101
+ def _split_heads(self, tensor, num_heads, attn_head_size):
102
+ """
103
+ Splits hidden_size dim into attn_head_size and num_heads
104
+ """
105
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
106
+ tensor = tensor.view(new_shape)
107
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
108
+
109
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
110
+ """
111
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
112
+ """
113
+
114
+ # re-assemble all head outputs side by side
115
+ # (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
116
+ tensor = tensor.transpose(1, 2).contiguous()
117
+ tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
118
+
119
+ return tensor
120
+
121
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
122
+ # unlike GPTNeo's SelfAttention, divide by the square root of the dimension of the query and the key
123
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * (1.0 / math.sqrt(self.head_dim))
124
+
125
+ if self.is_causal:
126
+ query_length, key_length = query.size(-2), key.size(-2)
127
+
128
+ # fill the upper left part of the attention weights with inf
129
+ attn_weights = attn_weights.masked_fill(
130
+ self.bias[:, :, key_length - query_length : key_length, :key_length] == 0,
131
+ torch.finfo(attn_weights.dtype).min,
132
+ )
133
+
134
+ if attention_mask is not None:
135
+ # Apply the attention mask
136
+ attn_weights = attn_weights + attention_mask
137
+
138
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
139
+ attn_weights = attn_weights.to(value.dtype)
140
+ attn_weights = self.attn_dropout(attn_weights)
141
+
142
+ # Mask heads if we want to
143
+ if head_mask is not None:
144
+ attn_weights = attn_weights * head_mask
145
+
146
+ # (batch, num_heads, seq_len, seq_len) x (batch, num_heads, seq_len, attn_head_size)
147
+ # -> (batch, num_heads, seq_len, attn_head_size)
148
+ attn_output = torch.matmul(attn_weights, value)
149
+
150
+ return attn_output, attn_weights
151
+
152
+ def forward(
153
+ self,
154
+ hidden_states,
155
+ attention_mask=None,
156
+ past_key_values=None,
157
+ head_mask=None,
158
+ use_cache=False,
159
+ output_attentions=False,
160
+ cache_position=None,
161
+ ):
162
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
163
+ query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
164
+
165
+ query = self._split_heads(query, self.num_heads, self.head_dim)
166
+ key = self._split_heads(key, self.num_heads, self.head_dim)
167
+ value = self._split_heads(value, self.num_heads, self.head_dim)
168
+
169
+ if past_key_values is not None:
170
+ key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
171
+
172
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
173
+
174
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
175
+ attn_output = self.out_proj(attn_output)
176
+ attn_output = self.resid_dropout(attn_output)
177
+
178
+ return attn_output, attn_weights
179
+
180
+
181
+ class BarkSelfFlashAttention2(BarkSelfAttention):
182
+ """
183
+ Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
184
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
185
+ flash attention and deal with padding tokens in case the input contains any of them.
186
+ """
187
+
188
+ def __init__(self, *args, **kwargs):
189
+ super().__init__(*args, **kwargs)
190
+
191
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
192
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
193
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
194
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
195
+
196
+ def _split_heads(self, tensor, num_heads, attn_head_size):
197
+ """
198
+ Splits hidden_size dim into attn_head_size and num_heads
199
+ """
200
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
201
+ tensor = tensor.view(new_shape)
202
+ # Flash attention requires the input to have the shape
203
+ # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
204
+ return tensor
205
+
206
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
207
+ """
208
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
209
+ """
210
+ # re-assemble all head outputs side by side
211
+ # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
212
+ tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
213
+ return tensor
214
+
215
+ def forward(
216
+ self,
217
+ hidden_states,
218
+ attention_mask=None,
219
+ past_key_values=None,
220
+ head_mask=None,
221
+ use_cache=False,
222
+ output_attentions=False,
223
+ cache_position=None,
224
+ ):
225
+ batch_size, query_len, _ = hidden_states.size()
226
+
227
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
228
+ query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
229
+
230
+ query = self._split_heads(query, self.num_heads, self.head_dim)
231
+ key = self._split_heads(key, self.num_heads, self.head_dim)
232
+ value = self._split_heads(value, self.num_heads, self.head_dim)
233
+
234
+ if past_key_values is not None:
235
+ key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
236
+
237
+ attn_output = _flash_attention_forward(
238
+ query,
239
+ key,
240
+ value,
241
+ attention_mask,
242
+ query_len,
243
+ dropout=self.dropout if self.training else 0.0,
244
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
245
+ is_causal=self.is_causal,
246
+ )
247
+
248
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
249
+ attn_output = self.out_proj(attn_output)
250
+ attn_output = self.resid_dropout(attn_output)
251
+
252
+ return attn_output, None
253
+
254
+
255
+ BARK_ATTENTION_CLASSES = {
256
+ "eager": BarkSelfAttention,
257
+ "flash_attention_2": BarkSelfFlashAttention2,
258
+ }
259
+
260
+
261
+ class BarkMLP(nn.Module):
262
+ def __init__(self, config):
263
+ super().__init__()
264
+ self.in_proj = nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=config.bias)
265
+ self.out_proj = nn.Linear(4 * config.hidden_size, config.hidden_size, bias=config.bias)
266
+ self.dropout = nn.Dropout(config.dropout)
267
+ self.gelu = nn.GELU()
268
+
269
+ def forward(self, hidden_states):
270
+ hidden_states = self.in_proj(hidden_states)
271
+ hidden_states = self.gelu(hidden_states)
272
+ hidden_states = self.out_proj(hidden_states)
273
+ hidden_states = self.dropout(hidden_states)
274
+ return hidden_states
275
+
276
+
277
+ class BarkBlock(GradientCheckpointingLayer):
278
+ def __init__(self, config, is_causal=False, layer_idx=None):
279
+ super().__init__()
280
+
281
+ if is_causal:
282
+ # if causal, the layerNorm bias is optional to stick with Bark choice of leaving optional bias
283
+ # in AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
284
+ self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias)
285
+ self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias)
286
+ else:
287
+ self.layernorm_1 = nn.LayerNorm(config.hidden_size)
288
+ self.layernorm_2 = nn.LayerNorm(config.hidden_size)
289
+
290
+ self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](
291
+ config, is_causal=is_causal, layer_idx=layer_idx
292
+ )
293
+
294
+ self.mlp = BarkMLP(config)
295
+
296
+ def forward(
297
+ self,
298
+ hidden_states,
299
+ past_key_values=None,
300
+ attention_mask=None,
301
+ head_mask=None,
302
+ use_cache=False,
303
+ output_attentions=False,
304
+ cache_position=None,
305
+ ):
306
+ intermediary_hidden_states = self.layernorm_1(hidden_states)
307
+
308
+ attn_outputs = self.attn(
309
+ intermediary_hidden_states,
310
+ past_key_values=past_key_values,
311
+ attention_mask=attention_mask,
312
+ head_mask=head_mask,
313
+ use_cache=use_cache,
314
+ output_attentions=output_attentions,
315
+ cache_position=cache_position,
316
+ )
317
+
318
+ attn_output = attn_outputs[0] # output_attn: output, present_key_values, (attn_weights)
319
+ outputs = attn_outputs[1:]
320
+
321
+ intermediary_hidden_states = hidden_states + attn_output
322
+ intermediary_hidden_states = intermediary_hidden_states + self.mlp(
323
+ self.layernorm_2(intermediary_hidden_states)
324
+ )
325
+
326
+ return (intermediary_hidden_states,) + outputs
327
+
328
+
329
+ @auto_docstring
330
+ class BarkPreTrainedModel(PreTrainedModel):
331
+ config: BarkConfig
332
+ supports_gradient_checkpointing = False
333
+ _supports_flash_attn = True
334
+
335
+ def _init_weights(self, module):
336
+ """Initialize the weights."""
337
+ if isinstance(module, (nn.Linear,)):
338
+ # Slightly different from the TF version which uses truncated_normal for initialization
339
+ # cf https://github.com/pytorch/pytorch/pull/5617
340
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
341
+ if module.bias is not None:
342
+ module.bias.data.zero_()
343
+ elif isinstance(module, nn.Embedding):
344
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
345
+ if module.padding_idx is not None:
346
+ module.weight.data[module.padding_idx].zero_()
347
+ elif isinstance(module, nn.LayerNorm):
348
+ module.bias.data.zero_()
349
+ module.weight.data.fill_(1.0)
350
+
351
+ def __init__(self, *inputs, **kwargs):
352
+ super().__init__(*inputs, **kwargs)
353
+
354
+ @property
355
+ def device(self) -> torch.device:
356
+ """
357
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
358
+ device).
359
+ """
360
+
361
+ # if has _hf_hook, has been offloaded so the device has to be found in the hook
362
+ if not hasattr(self, "_hf_hook"):
363
+ return get_parameter_device(self)
364
+ for module in self.modules():
365
+ if (
366
+ hasattr(module, "_hf_hook")
367
+ and hasattr(module._hf_hook, "execution_device")
368
+ and module._hf_hook.execution_device is not None
369
+ ):
370
+ return torch.device(module._hf_hook.execution_device)
371
+
372
+ return get_parameter_device(self)
373
+
374
+
375
+ # GPT2-like autoregressive model
376
+ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
377
+ config: BarkSubModelConfig
378
+
379
+ def __init__(self, config):
380
+ super().__init__(config)
381
+ self.config = config
382
+
383
+ # initialize as an autoregressive GPT-like model
384
+ self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size)
385
+ self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
386
+
387
+ self.drop = nn.Dropout(config.dropout)
388
+
389
+ self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)])
390
+
391
+ self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)
392
+
393
+ self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
394
+ self.gradient_checkpointing = False
395
+
396
+ # Initialize weights and apply final processing
397
+ self.post_init()
398
+
399
+ def get_output_embeddings(self):
400
+ # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
401
+ # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
402
+ return None
403
+
404
+ def get_input_embeddings(self):
405
+ return self.input_embeds_layer
406
+
407
+ def set_input_embeddings(self, new_embeddings):
408
+ self.input_embeds_layer = new_embeddings
409
+
410
+ def prepare_inputs_for_generation(
411
+ self,
412
+ input_ids,
413
+ attention_mask=None,
414
+ input_embeds=None,
415
+ past_key_values=None,
416
+ position_ids=None,
417
+ use_cache=None,
418
+ cache_position=None,
419
+ **kwargs,
420
+ ):
421
+ # Overwritten -- bark uses `input_embeds`, not `inputs_embeds`
422
+
423
+ model_inputs = super().prepare_inputs_for_generation(
424
+ input_ids,
425
+ attention_mask=attention_mask,
426
+ inputs_embeds=input_embeds,
427
+ past_key_values=past_key_values,
428
+ position_ids=position_ids,
429
+ use_cache=use_cache,
430
+ cache_position=cache_position,
431
+ **kwargs,
432
+ )
433
+ model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
434
+ return model_inputs
435
+
436
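+ # A minimal sketch of the renaming performed above (illustrative only): the
+ # standard `inputs_embeds` entry produced by `GenerationMixin` becomes Bark's
+ # `input_embeds`:
+ #
+ #   model_inputs = {"input_ids": input_ids, "inputs_embeds": embeds}
+ #   model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
+ #   # -> {"input_ids": input_ids, "input_embeds": embeds}
+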
+ @auto_docstring
437
+ def forward(
438
+ self,
439
+ input_ids: Optional[torch.Tensor] = None,
440
+ past_key_values: Optional[Cache] = None,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ position_ids: Optional[torch.Tensor] = None,
443
+ head_mask: Optional[torch.Tensor] = None,
444
+ labels: Optional[torch.LongTensor] = None,
445
+ input_embeds: Optional[torch.Tensor] = None,
446
+ use_cache: Optional[bool] = None,
447
+ output_attentions: Optional[bool] = None,
448
+ output_hidden_states: Optional[bool] = None,
449
+ return_dict: Optional[bool] = None,
450
+ cache_position: Optional[torch.Tensor] = None,
451
+ ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
452
+ r"""
453
+ input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
454
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
455
+ Here, due to `Bark` particularities, if `past_key_values` is used, `input_embeds` will be ignored and you
456
+ have to use `input_ids`. If `past_key_values` is not used and `use_cache` is set to `True`, `input_embeds`
457
+ is used in priority instead of `input_ids`.
458
+ """
459
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
460
+ output_hidden_states = (
461
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
462
+ )
463
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
464
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
465
+
466
+ loss = None
467
+ if labels is not None:
468
+ raise NotImplementedError(
469
+ "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model."
470
+ )
471
+
472
+ # Verify if input_embeds already exists
473
+ # then compute embeddings.
474
+ if input_ids is not None and input_embeds is not None:
475
+ raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
476
+ elif input_embeds is not None and past_key_values is None:
477
+ # we want to return the input_embeds in priority so that it is in line with a weird hack
478
+ # of Bark which concatenates two bits of the input_embeds on the first forward pass of the semantic model
479
+ pass
480
+ elif input_ids is not None:
481
+ input_embeds = self.input_embeds_layer(input_ids) # token embeddings of shape (b, t, n_embd)
482
+ elif input_embeds is not None:
483
+ pass
484
+ else:
485
+ raise ValueError("You have to specify either input_ids or input_embeds")
486
+
487
+ input_shape = input_embeds.size()[:-1]
488
+ batch_size = input_embeds.shape[0]
489
+ seq_length = input_shape[-1]
490
+
491
+ device = input_ids.device if input_ids is not None else input_embeds.device
492
+
493
+ if self.gradient_checkpointing and self.training:
494
+ if use_cache:
495
+ logger.warning_once(
496
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
497
+ )
498
+ use_cache = False
499
+
500
+ if use_cache and past_key_values is None:
501
+ past_key_values = DynamicCache(config=self.config)
502
+ if use_cache and isinstance(past_key_values, tuple):
503
+ logger.warning_once(
504
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
505
+ "You should pass an instance of `DynamicCache` instead, e.g. "
506
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
507
+ )
508
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
509
+
510
+ past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
511
+
512
+ if position_ids is None:
513
+ position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
514
+ position_ids = position_ids.unsqueeze(0) # shape (1, seq_length)
515
+
516
+ position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd)
517
+
518
+ # Attention mask.
519
+ if attention_mask is not None:
520
+ if batch_size <= 0:
521
+ raise ValueError("batch_size has to be defined and > 0")
522
+ if self.config._attn_implementation == "flash_attention_2":
523
+ attention_mask = attention_mask if 0 in attention_mask else None
524
+ else:
525
+ attention_mask = attention_mask.view(batch_size, -1)
526
+ # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
527
+ # from_seq_length is 1 to easily broadcast
528
+ attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
529
+
530
+ # Prepare head mask if needed
531
+ # 1.0 in head_mask indicate we keep the head
532
+ # attention_probs has shape bsz x num_heads x N x N
533
+ # head_mask has shape num_layers x batch x num_heads x N x N
534
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
535
+
536
+ hidden_states = self.drop(input_embeds + position_embeds)
537
+ output_shape = input_shape + (hidden_states.size(-1),)
538
+
539
+ all_self_attentions = () if output_attentions else None
540
+ all_hidden_states = () if output_hidden_states else None
541
+
542
+ for i, block in enumerate(self.layers):
543
+ if output_hidden_states:
544
+ all_hidden_states = all_hidden_states + (hidden_states,)
545
+
546
+ outputs = block(
547
+ hidden_states,
548
+ past_key_values=past_key_values,
549
+ attention_mask=attention_mask,
550
+ head_mask=head_mask[i],
551
+ use_cache=use_cache,
552
+ output_attentions=output_attentions,
553
+ cache_position=cache_position,
554
+ )
555
+
556
+ hidden_states = outputs[0]
557
+
558
+ if output_attentions:
559
+ all_self_attentions = all_self_attentions + (outputs[1],)
560
+
561
+ hidden_states = self.layernorm_final(hidden_states)
562
+
563
+ hidden_states = hidden_states.view(output_shape)
564
+
565
+ # Add last hidden state
566
+ if output_hidden_states:
567
+ all_hidden_states = all_hidden_states + (hidden_states,)
568
+
569
+ logits = self.lm_head(hidden_states)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None
574
+ )
575
+
576
+ return CausalLMOutputWithPast(
577
+ loss=loss,
578
+ logits=logits,
579
+ past_key_values=past_key_values,
580
+ hidden_states=all_hidden_states,
581
+ attentions=all_self_attentions,
582
+ )
583
+
584
+
585
+ @auto_docstring(
586
+ custom_intro="""
587
+ Bark semantic (or text) model. It shares the same architecture as the coarse model.
588
+ It is a GPT-2 like autoregressive model with a language modeling head on top.
589
+ """
590
+ )
591
+ class BarkSemanticModel(BarkCausalModel):
592
+ base_model_prefix = "semantic"
593
+ config: BarkSemanticConfig
594
+
595
+ def generate(
596
+ self,
597
+ input_ids: torch.Tensor,
598
+ semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
599
+ history_prompt: Optional[dict[str, torch.Tensor]] = None,
600
+ attention_mask: Optional[torch.Tensor] = None,
601
+ **kwargs,
602
+ ) -> torch.LongTensor:
603
+ """
604
+ Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt.
605
+
606
+ Args:
607
+ input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
608
+ Input ids, i.e. tokenized input sentences. Will be truncated up to
609
+ semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as
610
+ long as the longest generation among the batch.
611
+ semantic_generation_config (`BarkSemanticGenerationConfig`):
612
+ Generation config indicating how to generate the semantic tokens.
613
+ history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
614
+ Optional `Bark` speaker prompt.
615
+ attention_mask (`Optional[torch.Tensor]`, *optional*):
616
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
617
+
618
+ - 1 for tokens that are **not masked**,
619
+ - 0 for tokens that are **masked**.
620
+
621
+ [What are attention masks?](../glossary#attention-mask)
622
+ Returns:
623
+ torch.LongTensor: Output semantic tokens.
624
+ """
625
+ if semantic_generation_config is None:
626
+ raise ValueError("`semantic_generation_config` has to be provided")
627
+
628
+ batch_size = input_ids.shape[0]
629
+
630
+ max_input_semantic_length = semantic_generation_config.max_input_semantic_length
631
+
632
+ input_ids = input_ids + semantic_generation_config.text_encoding_offset
633
+
634
+ if attention_mask is not None:
635
+ input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token)
636
+
637
+ if history_prompt is not None:
638
+ semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:]
639
+ semantic_history = nn.functional.pad(
640
+ semantic_history,
641
+ (0, max_input_semantic_length - len(semantic_history)),
642
+ value=semantic_generation_config.semantic_pad_token,
643
+ mode="constant",
644
+ )
645
+ else:
646
+ semantic_history = torch.tensor(
647
+ [semantic_generation_config.semantic_pad_token] * max_input_semantic_length, dtype=torch.int
648
+ ).to(self.device)
649
+
650
+ semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0)
651
+
652
+ infer_array = torch.tensor(
653
+ [[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int
654
+ ).to(self.device)
655
+
656
+ input_embeds = torch.cat(
657
+ [
658
+ self.input_embeds_layer(input_ids[:, :max_input_semantic_length])
659
+ + self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]),
660
+ self.input_embeds_layer(infer_array),
661
+ ],
662
+ dim=1,
663
+ )
664
+
665
+ tokens_to_suppress = list(
666
+ range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token)
667
+ )
668
+ tokens_to_suppress.extend(
669
+ list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
670
+ )
671
+
672
+ suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)
673
+
674
+ min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
675
+ early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
676
+ eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
677
+ )
678
+
679
+ # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
680
+ # (except to get the input seq_len - that's why we keep the first 257 tokens)
681
+ semantic_output = super().generate(
682
+ torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
683
+ input_embeds=input_embeds,
684
+ logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
685
+ generation_config=semantic_generation_config,
686
+ **kwargs,
687
+ ) # size: 10048
688
+
689
+ # take the generated semantic tokens
690
+ semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
691
+
692
+ return semantic_output
693
+
694
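+ # A minimal usage sketch (assumes a loaded `BarkModel` named `model` and processor
+ # inputs; the sub-config below comes from the nested generation config, as in
+ # `BarkModel.generate`):
+ #
+ #   semantic_cfg = BarkSemanticGenerationConfig(**model.generation_config.semantic_config)
+ #   semantic_tokens = model.semantic.generate(
+ #       inputs["input_ids"],
+ #       attention_mask=inputs.get("attention_mask"),
+ #       semantic_generation_config=semantic_cfg,
+ #   )
+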
+
695
+ @auto_docstring(
696
+ custom_intro="""
697
+ Bark coarse acoustics model.
698
+ It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a
699
+ language modeling head on top.
700
+ """
701
+ )
702
+ class BarkCoarseModel(BarkCausalModel):
703
+ base_model_prefix = "coarse_acoustics"
704
+ config: BarkCoarseConfig
705
+
706
+ def preprocess_histories(
707
+ self,
708
+ max_coarse_history: int,
709
+ semantic_to_coarse_ratio: float,
710
+ batch_size: int,
711
+ semantic_generation_config: BarkSemanticGenerationConfig,
712
+ codebook_size: int,
713
+ history_prompt: Optional[dict[str, torch.Tensor]] = None,
714
+ ):
715
+ """
716
+ Preprocess the optional `Bark` speaker prompts before `self.generate`.
717
+
718
+ Args:
719
+ max_coarse_history (`int`):
720
+ Maximum size of coarse tokens used.
721
+ semantic_to_coarse_ratio (`float`):
722
+ Ratio between the coarse token rate and the semantic token rate, i.e. the number of coarse tokens generated per semantic token.
723
+ batch_size (`int`):
724
+ Batch size, i.e. the number of samples.
725
+ semantic_generation_config (`BarkSemanticGenerationConfig`):
726
+ Generation config indicating how to generate the semantic tokens.
727
+ codebook_size (`int`):
728
+ Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
729
+ history_prompt (`Optional[dict[str,torch.Tensor]]`):
730
+ Optional `Bark` speaker prompt.
731
+ Returns:
732
+ `tuple(torch.FloatTensor)`:
733
+ - **x_semantic_history** (`torch.FloatTensor`) -- Processed semantic speaker prompt.
734
+ - **x_coarse_history** (`torch.FloatTensor`) -- Processed coarse speaker prompt.
735
+ """
736
+ if history_prompt is not None:
737
+ x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0)
738
+ # clone to avoid modifying history_prompt.coarse_prompt
739
+ x_coarse_history = history_prompt["coarse_prompt"].clone()
740
+
741
+ # offset x_coarse_history
742
+ if codebook_size is not None:
743
+ for n in range(1, x_coarse_history.shape[0]):
744
+ # offset
745
+ x_coarse_history[n, :] += codebook_size * n
746
+
747
+ # flatten x_coarse_history
748
+ x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1)
749
+
750
+ x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size
751
+
752
+ x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0)
753
+ # e.g: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens
754
+ # dedicated to second codebook.
755
+
756
+ max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
757
+ # trim histories correctly
758
+ n_semantic_hist_provided = min(
759
+ [
760
+ max_semantic_history,
761
+ x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2,
762
+ int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)),
763
+ ]
764
+ )
765
+
766
+ n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
767
+
768
+ x_semantic_history = x_semantic_history[:, -n_semantic_hist_provided:].int()
769
+ x_coarse_history = x_coarse_history[:, -n_coarse_hist_provided:].int()
770
+ # bit of a hack for time alignment (sounds better) - from Bark original implementation
771
+ x_coarse_history = x_coarse_history[:, :-2]
772
+
773
+ else:
774
+ # shape: (batch_size, 0)
775
+ x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
776
+ x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
777
+
778
+ return x_semantic_history, x_coarse_history
779
+
780
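+ # Worked example of the history trimming above (assuming the typical Bark defaults
+ # coarse_rate_hz=75, semantic_rate_hz=49.9, n_coarse_codebooks=2, max_coarse_history=630):
+ #
+ #   semantic_to_coarse_ratio = 75 / 49.9 * 2            # ~3.006 coarse tokens per semantic token
+ #   max_semantic_history = int(np.floor(630 / 3.006))   # 209 semantic tokens
+ #
+ # so at most ~209 semantic history tokens and ~630 coarse history tokens are kept.
+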
+ def generate(
781
+ self,
782
+ semantic_output: torch.Tensor,
783
+ semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
784
+ coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
785
+ codebook_size: int = 1024,
786
+ history_prompt: Optional[dict[str, torch.Tensor]] = None,
787
+ return_output_lengths: Optional[bool] = None,
788
+ **kwargs,
789
+ ) -> Union[torch.LongTensor, tuple[torch.LongTensor, torch.LongTensor]]:
790
+ """
791
+ Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
792
+ prompt.
793
+
794
+ Args:
795
+ semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*):
796
+ Input text semantic ids, i.e. the output of `BarkSemanticModel.generate`.
797
+ semantic_generation_config (`BarkSemanticGenerationConfig`):
798
+ Generation config indicating how to generate the semantic tokens.
799
+ coarse_generation_config (`BarkCoarseGenerationConfig`):
800
+ Generation config indicating how to generate the coarse tokens.
801
+ codebook_size (`int`, *optional*, defaults to 1024):
802
+ Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
803
+ history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
804
+ Optional `Bark` speaker prompt.
805
+ return_output_lengths (`bool`, *optional*):
806
+ Whether or not to return the output lengths. Useful when batching.
807
+ Returns:
808
+ By default:
809
+ torch.LongTensor: Output coarse acoustics tokens.
810
+ If `return_output_lengths=True`:
811
+ `tuple(torch.Tensor, torch.Tensor)`: The output coarse acoustics tokens, and the length of each sample
812
+ of the batch.
813
+ """
814
+
815
+ if semantic_generation_config is None:
816
+ raise ValueError("`semantic_generation_config` has to be provided")
817
+
818
+ if coarse_generation_config is None:
819
+ raise ValueError("`coarse_generation_config` has to be provided")
820
+
821
+ max_coarse_input_length = coarse_generation_config.max_coarse_input_length
822
+ max_coarse_history = coarse_generation_config.max_coarse_history
823
+ sliding_window_len = coarse_generation_config.sliding_window_len
824
+
825
+ # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e the pad_token
826
+ # used in the next model
827
+ semantic_output.masked_fill_(
828
+ semantic_output == semantic_generation_config.semantic_pad_token,
829
+ coarse_generation_config.coarse_semantic_pad_token,
830
+ )
831
+
832
+ semantic_to_coarse_ratio = (
833
+ coarse_generation_config.coarse_rate_hz
834
+ / semantic_generation_config.semantic_rate_hz
835
+ * coarse_generation_config.n_coarse_codebooks
836
+ )
837
+ max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
838
+
839
+ output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
840
+ output_lengths = torch.floor(
841
+ output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
842
+ )
843
+ output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()
844
+
845
+ max_generated_len = torch.max(output_lengths).item()
846
+
847
+ batch_size = semantic_output.shape[0]
848
+
849
+ x_semantic_history, x_coarse = self.preprocess_histories(
850
+ history_prompt=history_prompt,
851
+ max_coarse_history=max_coarse_history,
852
+ semantic_to_coarse_ratio=semantic_to_coarse_ratio,
853
+ batch_size=batch_size,
854
+ semantic_generation_config=semantic_generation_config,
855
+ codebook_size=codebook_size,
856
+ )
857
+ base_semantic_idx = x_semantic_history.shape[1]
858
+
859
+ semantic_output = torch.hstack([x_semantic_history, semantic_output])
860
+
861
+ n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))
862
+
863
+ total_generated_len = 0
864
+
865
+ len_coarse_history = x_coarse.shape[1]
866
+
867
+ for _ in range(n_window_steps):
868
+ semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio))
869
+
870
+ # pad from right side
871
+ input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :]
872
+ input_coarse = input_coarse[:, :max_coarse_input_length]
873
+ input_coarse = F.pad(
874
+ input_coarse,
875
+ (0, max_coarse_input_length - input_coarse.shape[-1]),
876
+ "constant",
877
+ coarse_generation_config.coarse_semantic_pad_token,
878
+ )
879
+
880
+ input_coarse = torch.hstack(
881
+ [
882
+ input_coarse,
883
+ torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
884
+ x_coarse[:, -max_coarse_history:],
885
+ ]
886
+ )
887
+
888
+ alternating_logits_processor = AlternatingCodebooksLogitsProcessor(
889
+ input_coarse.shape[1],
890
+ semantic_generation_config.semantic_vocab_size,
891
+ codebook_size,
892
+ )
893
+
894
+ output_coarse = super().generate(
895
+ input_coarse,
896
+ logits_processor=[alternating_logits_processor],
897
+ max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len),
898
+ generation_config=coarse_generation_config,
899
+ **kwargs,
900
+ )
901
+
902
+ input_coarse_len = input_coarse.shape[1]
903
+
904
+ x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
905
+ total_generated_len = x_coarse.shape[1] - len_coarse_history
906
+
907
+ del output_coarse
908
+
909
+ coarse_output = x_coarse[:, len_coarse_history:]
910
+
911
+ if return_output_lengths:
912
+ return coarse_output, output_lengths
913
+
914
+ return coarse_output
915
+
916
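+ # Worked example of the length bookkeeping above (assuming the typical defaults
+ # semantic_to_coarse_ratio ~= 3.006, n_coarse_codebooks=2, sliding_window_len=60):
+ # 513 non-pad semantic tokens give floor(513 * 3.006 / 2) = 771 coarse frames,
+ # i.e. 771 * 2 = 1542 coarse tokens, generated in ceil(1542 / 60) = 26 sliding windows.
+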
+
917
+ @auto_docstring(
918
+ custom_intro="""
919
+ Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and
920
+ language modeling heads, one for each codebook.
921
+ """
922
+ )
923
+ class BarkFineModel(BarkPreTrainedModel):
924
+ base_model_prefix = "fine_acoustics"
925
+ config: BarkFineConfig
926
+ main_input_name = "codebook_idx"
927
+
928
+ def __init__(self, config):
929
+ # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
930
+ super().__init__(config)
931
+ self.config = config
932
+
933
+ # initialize a modified non causal GPT-like model
934
+ # note that there is one embedding layer and one lm_head for each codebook of Encodec
935
+ self.input_embeds_layers = nn.ModuleList(
936
+ [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)]
937
+ )
938
+ self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
939
+
940
+ self.drop = nn.Dropout(config.dropout)
941
+
942
+ self.layers = nn.ModuleList(
943
+ [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)]
944
+ )
945
+
946
+ self.layernorm_final = nn.LayerNorm(config.hidden_size)
947
+
948
+ self.lm_heads = nn.ModuleList(
949
+ [
950
+ nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
951
+ for _ in range(config.n_codes_given, config.n_codes_total)
952
+ ]
953
+ )
954
+ self.gradient_checkpointing = False
955
+ self.n_codes_total = config.n_codes_total
956
+
957
+ # Initialize weights and apply final processing
958
+ self.post_init()
959
+
960
+ def get_input_embeddings(self):
961
+ # one embedding layer for each codebook
962
+ return self.input_embeds_layers
963
+
964
+ def set_input_embeddings(self, new_embeddings):
965
+ # one embedding layer for each codebook
966
+ self.input_embeds_layers = new_embeddings
967
+
968
+ def get_output_embeddings(self):
969
+ # one lm_head for each codebook
970
+ return self.lm_heads
971
+
972
+ def set_output_embeddings(self, new_output_embeddings):
973
+ # one lm_head for each codebook
974
+ self.lm_heads = new_output_embeddings
975
+
976
+ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
977
+ old_embeddings_list = self.get_input_embeddings()
978
+ new_embeddings_list = nn.ModuleList(
979
+ [
980
+ self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
981
+ for old_embeddings in old_embeddings_list
982
+ ]
983
+ )
984
+ self.set_input_embeddings(new_embeddings_list)
985
+ new_num_tokens = new_embeddings_list[0].weight.shape[0]
986
+
987
+ # if word embeddings are not tied, make sure that lm head is resized as well
988
+ if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
989
+ old_lm_head_list = self.get_output_embeddings()
990
+ new_lm_head_list = nn.ModuleList(
991
+ [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
992
+ )
993
+ self.set_output_embeddings(new_lm_head_list)
994
+
995
+ return self.get_input_embeddings()
996
+
997
+ def resize_token_embeddings(
998
+ self,
999
+ new_num_tokens: Optional[int] = None,
1000
+ pad_to_multiple_of: Optional[int] = None,
1001
+ mean_resizing: bool = True,
1002
+ ) -> nn.Embedding:
1003
+ """
1004
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
1005
+
1006
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
1007
+
1008
+ Arguments:
1009
+ new_num_tokens (`int`, *optional*):
1010
+ The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
1011
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
1012
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
1013
+ pad_to_multiple_of (`int`, *optional*):
1014
+ If set will pad the embedding matrix to a multiple of the provided value.
1015
+
1016
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
1017
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
1018
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
1019
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
1020
+ mean_resizing (`bool`):
1021
+ Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
1022
+ covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
1023
+
1024
+ Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
1025
+ where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
1026
+ old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
1027
+ Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
1028
+
1029
+ Return:
1030
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
1031
+ """
1032
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
1033
+ if new_num_tokens is None and pad_to_multiple_of is None:
1034
+ return model_embeds
1035
+
1036
+ # Update base model and current model config
1037
+ self.config.output_vocab_size = model_embeds[0].weight.shape[0]
1038
+ self.config.vocab_size = model_embeds[0].weight.shape[0]
1039
+ self.output_vocab_size = model_embeds[0].weight.shape[0]
1040
+ self.vocab_size = model_embeds[0].weight.shape[0]
1041
+
1042
+ # Tie weights again if needed
1043
+ self.tie_weights()
1044
+
1045
+ return model_embeds
1046
+
1047
+ def _tie_weights(self):
1048
+ if getattr(self.config, "tie_word_embeddings", True):
1049
+ self._tied_weights_keys = []
1050
+ output_embeddings = self.get_output_embeddings()
1051
+ input_embeddings = self.get_input_embeddings()
1052
+
1053
+ for i in range(self.config.n_codes_total - self.config.n_codes_given):
1054
+ # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
1055
+ self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
1056
+ self._tied_weights_keys.append(f"lm_heads.{i}.weight")
1057
+
1058
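+ # Sketch of the tying pattern above: with the typical fine-model defaults
+ # n_codes_given=1 and n_codes_total=8, each head shares its weight with the *next*
+ # embedding table, i.e.
+ #
+ #   lm_heads[i].weight is input_embeds_layers[i + 1].weight   for i in range(7)
+ #
+ # because codebook 0 has no fine head - it is predicted by the coarse model.
+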
+ def tie_weights(self):
1059
+ """
1060
+ Tie the weights between the input embeddings list and the output embeddings list.
1061
+
1062
+ If the `torchscript` flag is set in the configuration, parameter sharing can't be handled, so we clone the
1063
+ weights instead.
1064
+ """
1065
+ for module in self.modules():
1066
+ if hasattr(module, "_tie_weights"):
1067
+ module._tie_weights()
1068
+
1069
+ @auto_docstring
1070
+ def forward(
1071
+ self,
1072
+ codebook_idx: int, # an additional idx corresponding to the id of the codebook that will be predicted
1073
+ input_ids: Optional[torch.Tensor] = None,
1074
+ attention_mask: Optional[torch.Tensor] = None,
1075
+ position_ids: Optional[torch.Tensor] = None,
1076
+ head_mask: Optional[torch.Tensor] = None,
1077
+ labels: Optional[torch.LongTensor] = None,
1078
+ input_embeds: Optional[torch.Tensor] = None,
1079
+ output_attentions: Optional[bool] = None,
1080
+ output_hidden_states: Optional[bool] = None,
1081
+ return_dict: Optional[bool] = None,
1082
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
1083
+ r"""
1084
+ codebook_idx (`int`):
1085
+ Index of the codebook that will be predicted.
1086
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1087
+ NOT IMPLEMENTED YET.
1088
+ input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
1089
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
1090
+ `past_key_values` is used, optionally only the last `input_embeds` have to be input (see
1091
+ `past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into
1092
+ associated vectors than the model's internal embedding lookup matrix.
1093
+ """
1094
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1095
+ output_hidden_states = (
1096
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1097
+ )
1098
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1099
+
1100
+ loss = None
1101
+ if labels is not None:
1102
+ raise NotImplementedError("Training is not implemented yet")
1103
+
1104
+ if codebook_idx == 0:
1105
+ raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model")
1106
+
1107
+ if input_ids is not None and input_embeds is not None:
1108
+ raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
1109
+
1110
+ if input_ids is None and input_embeds is None:
1111
+ raise ValueError("You have to specify either input_ids or input_embeds")
1112
+
1113
+ if input_ids is not None:
1114
+ # the input_embeddings are the sum of the j previous codebooks embeddings before
1115
+ # the current codebook_idx codebook
1116
+
1117
+ # forward the GPT model itself
1118
+ input_embeds = [
1119
+ input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1)
1120
+ for i, input_embeds_layer in enumerate(self.input_embeds_layers)
1121
+ ] # token embeddings of shape (b, t, n_embd)
1122
+ input_embeds = torch.cat(input_embeds, dim=-1)
1123
+ input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)
1124
+
1125
+ input_shape = input_embeds.size()[:-1]
1126
+ batch_size = input_embeds.shape[0]
1127
+ seq_length = input_shape[1]
1128
+
1129
+ device = input_ids.device if input_ids is not None else input_embeds.device
1130
+
1131
+ if position_ids is None:
1132
+ position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
1133
+ position_ids = position_ids.unsqueeze(0) # shape (1, seq_length)
1134
+
1135
+ position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd)
1136
+
1137
+ # Attention mask.
1138
+ if attention_mask is not None:
1139
+ if batch_size <= 0:
1140
+ raise ValueError("batch_size has to be defined and > 0")
1141
+ if self.config._attn_implementation == "flash_attention_2":
1142
+ attention_mask = attention_mask if 0 in attention_mask else None
1143
+ else:
1144
+ # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
1145
+ # from_seq_length is 1 to easily broadcast
1146
+ attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
1147
+
1148
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
1149
+
1150
+ hidden_states = self.drop(input_embeds + position_embeds)
1151
+ output_shape = input_shape + (hidden_states.size(-1),)
1152
+
1153
+ all_self_attentions = () if output_attentions else None
1154
+ all_hidden_states = () if output_hidden_states else None
1155
+
1156
+ for i, block in enumerate(self.layers):
1157
+ if output_hidden_states:
1158
+ all_hidden_states = all_hidden_states + (hidden_states,)
1159
+
1160
+ outputs = block(
1161
+ hidden_states,
1162
+ attention_mask=attention_mask,
1163
+ head_mask=head_mask[i],
1164
+ output_attentions=output_attentions,
1165
+ )
1166
+
1167
+ hidden_states = outputs[0]
1168
+
1169
+ if output_attentions:
1170
+ all_self_attentions = all_self_attentions + (outputs[1],)
1171
+
1172
+ hidden_states = self.layernorm_final(hidden_states)
1173
+ hidden_states = hidden_states.view(output_shape)
1174
+
1175
+ # Add last hidden state
1176
+ if output_hidden_states:
1177
+ all_hidden_states = all_hidden_states + (hidden_states,)
1178
+
1179
+ logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states)
1180
+
1181
+ if not return_dict:
1182
+ return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None)
1183
+
1184
+ return MaskedLMOutput(
1185
+ loss=loss,
1186
+ logits=logits,
1187
+ hidden_states=all_hidden_states,
1188
+ attentions=all_self_attentions,
1189
+ )
1190
+
1191
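+ # Equivalent sketch of the codebook-sum in `forward` above (illustrative): for
+ # codebook_idx=3, the model input is the sum of the embeddings of codebooks 0..3:
+ #
+ #   embeds = [layer(input_ids[:, :, i]) for i, layer in enumerate(self.input_embeds_layers)]
+ #   input_embeds = torch.stack(embeds, dim=-1)[..., : 3 + 1].sum(dim=-1)
+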
+ @torch.no_grad()
1192
+ def generate(
1193
+ self,
1194
+ coarse_output: torch.Tensor,
1195
+ semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
1196
+ coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
1197
+ fine_generation_config: Optional[BarkFineGenerationConfig] = None,
1198
+ codebook_size: int = 1024,
1199
+ history_prompt: Optional[dict[str, torch.Tensor]] = None,
1200
+ **kwargs,
1201
+ ) -> torch.LongTensor:
1202
+ """
1203
+ Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker
1204
+ prompt.
1205
+
1206
+ Args:
1207
+ coarse_output (`torch.Tensor` of shape (batch_size, seq_len)):
1208
+ Input coarse acoustics ids, i.e the output of `BarkCoarseModel.generate`.
1209
+ semantic_generation_config (`BarkSemanticGenerationConfig`):
1210
+ Generation config indicating how to generate the semantic tokens.
1211
+ coarse_generation_config (`BarkCoarseGenerationConfig`):
1212
+ Generation config indicating how to generate the coarse tokens.
1213
+ fine_generation_config (`BarkFineGenerationConfig`):
1214
+ Generation config indicating how to generate the fine tokens.
1215
+ codebook_size (`int`, *optional*, defaults to 1024):
1216
+ Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
1217
+ history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
1218
+ Optional `Bark` speaker prompt.
1219
+ Returns:
1220
+ torch.LongTensor: Output fine acoustics tokens.
1221
+ """
1222
+ if semantic_generation_config is None:
1223
+ raise ValueError("`semantic_generation_config` has to be provided")
1224
+
1225
+ if coarse_generation_config is None:
1226
+ raise ValueError("`coarse_generation_config` has to be provided")
1227
+
1228
+ if fine_generation_config is None:
1229
+ raise ValueError("`fine_generation_config` has to be provided")
1230
+
1231
+ # since we don't really use GenerationConfig through the fine model (autoencoder)
1232
+ # and since only temperature is used from the classic GenerationConfig parameters
1233
+ # manually impose the kwargs priority over the generation config
1234
+ temperature = kwargs.get("temperature", fine_generation_config.temperature)
1235
+
1236
+ max_fine_history_length = fine_generation_config.max_fine_history_length
1237
+ max_fine_input_length = fine_generation_config.max_fine_input_length
1238
+
1239
+ # shape: (batch, n_coarse_codebooks * seq_len)
1240
+ # new_shape: (batch, seq_len, n_coarse_codebooks)
1241
+ coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks)
1242
+
1243
+ # brings ids into the range [0, codebook_size - 1]
1244
+ coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size)
1245
+ batch_size = coarse_output.shape[0]
1246
+
1247
+ if history_prompt is not None:
1248
+ x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0)
1249
+ # transpose to get to shape (seq_len, n_fine_codebooks)
1250
+ else:
1251
+ x_fine_history = None
1252
+
1253
+ n_coarse = coarse_generation_config.n_coarse_codebooks
1254
+
1255
+ # pad the remaining codebook channels (n_fine_codebooks - n_coarse of them, i.e. 6 by default) with codebook_size
1256
+ fine_input = F.pad(
1257
+ coarse_output,
1258
+ (0, fine_generation_config.n_fine_codebooks - n_coarse),
1259
+ "constant",
1260
+ codebook_size,
1261
+ )
1262
+
1263
+ # prepend history if available (max max_fine_history_length)
1264
+ if x_fine_history is not None:
1265
+ fine_input = torch.cat([x_fine_history[:, -max_fine_history_length:, :], fine_input], dim=1)
1266
+
1267
+ # len of the fine_history that has been added to fine_input
1268
+ n_history = x_fine_history[:, -max_fine_history_length:, :].shape[1]
1269
+ else:
1270
+ n_history = 0
1271
+
1272
+ n_remove_from_end = 0
1273
+ # need to pad if too short (since non-causal model)
1274
+ if fine_input.shape[1] < max_fine_input_length:
1275
+ n_remove_from_end = max_fine_input_length - fine_input.shape[1]
1276
+ fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size)
1277
+
1278
+ # we can be lazy about the fractional loop and just keep overwriting codebooks:
+ # coarse_output.shape[1] - (max_fine_input_length - n_history) equals -n_remove_from_end,
+ # so if we had to pad because the input was too short, n_loops is always 1 (since n_remove_from_end > 0);
+ # otherwise we loop at least twice.
1282
+
1283
+ n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length
1284
+ n_loops = int(np.ceil(n_loops))
1285
+ n_loops = max(0, n_loops) + 1
1286
+
1287
+ for n_outer in range(n_loops):
1288
+ start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length])
1289
+
1290
+ start_fill_idx = min(
1291
+ [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length]
1292
+ )
1293
+ rel_start_fill_idx = start_fill_idx - start_idx
1294
+ input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
1295
+ for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
1296
+ logits = self.forward(n_inner, input_buffer).logits
1297
+ if temperature is None or temperature == 1.0:
1298
+ relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
1299
+ codebook_preds = torch.argmax(relevant_logits, -1)
1300
+ else:
1301
+ relevant_logits = logits[:, :, :codebook_size] / temperature
1302
+ # apply softmax
1303
+ probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length]
1304
+ # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size)
1305
+ probs = probs.reshape((-1, codebook_size))
1306
+ # multinomial then reshape : (batch_size*seq_len)-> (batch_size,seq_len)
1307
+ codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1)
1308
+ codebook_preds = codebook_preds.to(torch.int32)
1309
+ input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds
1310
+ del logits, codebook_preds
1311
+
1312
+ # transfer into fine_input
1313
+ for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
1314
+ fine_input[
1315
+ :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner
1316
+ ] = input_buffer[:, rel_start_fill_idx:, n_inner]
1317
+ del input_buffer
1318
+
1319
+ fine_input = fine_input.transpose(1, 2)[:, :, n_history:]
1320
+ if n_remove_from_end > 0:
1321
+ fine_input = fine_input[:, :, :-n_remove_from_end]
1322
+
1323
+ if fine_input.shape[-1] != coarse_output.shape[-2]:
1324
+ raise ValueError("input and output should have the same seq_len")
1325
+
1326
+ return fine_input
1327
+
1328
+
1329
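+ # Worked example of the windowing above (assuming the typical defaults
+ # max_fine_history_length=512 and max_fine_input_length=1024): continuing the
+ # 771-frame case with no speaker history (n_history=0), the input is padded by
+ # 1024 - 771 = 253 frames and
+ #   n_loops = max(0, ceil((771 - 1024) / 512)) + 1 = 1,
+ # i.e. a single window, consistent with the comment above about padded inputs.
+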
+ @auto_docstring(
1330
+ custom_intro="""
1331
+ The full Bark model, a text-to-speech model composed of 4 sub-models:
1332
+ - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal autoregressive transformer that takes
+ tokenized text as input and predicts semantic text tokens capturing the meaning of the text.
+ - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model): also a causal autoregressive transformer,
+ taking as input the output of the previous model. It regresses the first two audio codebooks required by `encodec`.
+ - [`BarkFineModel`] (the 'fine acoustics' model): this time a non-causal autoencoder transformer, which iteratively
+ predicts the remaining codebooks based on the sum of the previous codebooks' embeddings.
+ - once all the codebook channels are predicted, Bark uses the [`EncodecModel`] to decode the output audio array.
+
+ It should be noted that each of the first three modules can support conditional speaker embeddings to condition the
+ output sound according to a specific predefined voice.
1345
+ """
1346
+ )
1347
+ class BarkModel(BarkPreTrainedModel):
1348
+ config: BarkConfig
1349
+
1350
+ def __init__(self, config):
1351
+ super().__init__(config)
1352
+
1353
+ self.semantic = BarkSemanticModel(config.semantic_config)
1354
+ self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
1355
+ self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
1356
+
1357
+ self.codec_model = AutoModel.from_config(config.codec_config)
1358
+
1359
+ self.config = config
1360
+
1361
+ @classmethod
1362
+ def can_generate(cls) -> bool:
1363
+ # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from
1364
+ # `GenerationMixin` (it has a non-standard generation method), but one of the internal models do
1365
+ # (`BarkSemanticModel`). This means that the base `can_generate()` will return `False`, but we need to
1366
+ # override it so as to do `GenerationConfig` handling in multiple parts of the codebase.
1367
+ return True
1368
+
1369
+ @property
1370
+ def device(self) -> torch.device:
1371
+ """
1372
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1373
+ device).
1374
+ """
1375
+ # for bark_model, device must be verified on its sub-models
1376
+ # if has _hf_hook, has been offloaded so the device has to be found in the hook
1377
+ if not hasattr(self.semantic, "_hf_hook"):
1378
+ return get_parameter_device(self)
1379
+ for module in self.semantic.modules():
1380
+ if (
1381
+ hasattr(module, "_hf_hook")
1382
+ and hasattr(module._hf_hook, "execution_device")
1383
+ and module._hf_hook.execution_device is not None
1384
+ ):
1385
+ return torch.device(module._hf_hook.execution_device)
+
+ # fall back to the parameter device if no hook exposes an execution device
+ # (mirrors the fallback in `BarkCausalModel.device`)
+ return get_parameter_device(self)
1386
+
1387
+ def enable_cpu_offload(
1388
+ self,
1389
+ accelerator_id: Optional[int] = 0,
1390
+ **kwargs,
1391
+ ):
1392
+ r"""
1393
+ Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
1394
+ method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains on the accelerator until the next sub-model runs.
1395
+
1396
+ Args:
1397
+ accelerator_id (`int`, *optional*, defaults to 0):
1398
+ accelerator id on which the sub-models will be loaded and offloaded. This argument is deprecated.
1399
+ kwargs (`dict`, *optional*):
1400
+ additional keyword arguments:
1401
+ `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded.
1402
+ """
1403
+ if is_accelerate_available():
1404
+ from accelerate import cpu_offload_with_hook
1405
+ else:
1406
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")
1407
+
1408
+ gpu_id = kwargs.get("gpu_id", 0)
1409
+
1410
+ if gpu_id != 0:
1411
+ warnings.warn(
1412
+ "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
1413
+ FutureWarning,
1414
+ )
1415
+ accelerator_id = gpu_id
1416
+
1417
+ device_type = "cuda"
1418
+ if is_torch_accelerator_available():
1419
+ device_type = torch.accelerator.current_accelerator().type
1420
+ device = torch.device(f"{device_type}:{accelerator_id}")
1421
+
1422
+ torch_accelerator_module = getattr(torch, device_type)
1423
+ if self.device.type != "cpu":
1424
+ self.to("cpu")
1425
+ torch_accelerator_module.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1426
+
1427
+ # this layer is used outside the first forward pass of semantic so need to be loaded before semantic
1428
+ self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
1429
+
1430
+ hook = None
1431
+ for cpu_offloaded_model in [
1432
+ self.semantic,
1433
+ self.coarse_acoustics,
1434
+ self.fine_acoustics,
1435
+ ]:
1436
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
1437
+
1438
+ self.fine_acoustics_hook = hook
1439
+
1440
+ _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)
1441
+
1442
+ # We'll offload the last model manually.
1443
+ self.codec_model_hook = hook
1444
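+ # Minimal usage sketch (requires `accelerate` to be installed):
+ #
+ #   model = BarkModel.from_pretrained("suno/bark-small")
+ #   model.enable_cpu_offload()          # sub-models hop to the accelerator on demand
+ #   audio = model.generate(**inputs)    # peak memory ~ one sub-model at a time
+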
+
1445
+ def codec_decode(self, fine_output, output_lengths=None):
1446
+ """Turn quantized audio codes into audio array using encodec."""
1447
+
1448
+ fine_output = fine_output.transpose(0, 1)
1449
+ emb = self.codec_model.quantizer.decode(fine_output)
1450
+
1451
+ if output_lengths is not None:
1452
+ # encodec uses LSTMs which behaves differently with appended padding
1453
+ # decoding with encodec takes around 0.1% of the total generation time
1454
+ # to keep generation quality, we break batching
1455
+ out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
1456
+ audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
1457
+ else:
1458
+ out = self.codec_model.decoder(emb)
1459
+ audio_arr = out.squeeze(1) # squeeze the codebook dimension
1460
+
1461
+ return audio_arr
1462
+
1463
+ @torch.no_grad()
1464
+ def generate(
1465
+ self,
1466
+ input_ids: Optional[torch.Tensor] = None,
1467
+ history_prompt: Optional[dict[str, torch.Tensor]] = None,
1468
+ return_output_lengths: Optional[bool] = None,
1469
+ **kwargs,
1470
+ ) -> torch.LongTensor:
1471
+ """
1472
+ Generates audio from an input prompt and an additional optional `Bark` speaker prompt.
1473
+
1474
+ Args:
1475
+ input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
1476
+ Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the
1477
+ longest generation among the batch.
1478
+ history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
1479
+ Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
1480
+ kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:
1481
+
1482
+ - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
1483
+ - With a *semantic_*, *coarse_*, *fine_* prefix, they will be input for the `generate` method of the
1484
+ semantic, coarse and fine sub-models respectively. They take priority over the keywords without a prefix.
1485
+
1486
+ This means you can, for example, specify a generation strategy for all sub-models except one.
1487
+ return_output_lengths (`bool`, *optional*):
1488
+ Whether or not to return the waveform lengths. Useful when batching.
1489
+ Returns:
1490
+ By default:
1491
+ - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
1492
+ When `return_output_lengths=True`:
1493
+ Returns a tuple made of:
1494
+ - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
1495
+ - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch
1496
+ Example:
1497
+
1498
+ ```python
1499
+ >>> from transformers import AutoProcessor, BarkModel
1500
+
1501
+ >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
1502
+ >>> model = BarkModel.from_pretrained("suno/bark-small")
1503
+
1504
+ >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
1505
+ >>> voice_preset = "v2/en_speaker_6"
1506
+
1507
+ >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)
1508
+
1509
+ >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
1510
+ >>> audio_array = audio_array.cpu().numpy().squeeze()
1511
+ ```
1512
+ """
1513
+ # TODO (joao): workaround until nested generation config is compatible with PreTrainedModel
1514
+ # todo: dict
1515
+ semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config)
1516
+ coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config)
1517
+ fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config)
1518
+
1519
+ kwargs_semantic = {
1520
+ # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
1521
+ "attention_mask": kwargs.pop("attention_mask", None),
1522
+ "min_eos_p": kwargs.pop("min_eos_p", None),
1523
+ }
1524
+ kwargs_coarse = {}
1525
+ kwargs_fine = {}
1526
+ for key, value in kwargs.items():
1527
+ if key.startswith("semantic_"):
1528
+ key = key[len("semantic_") :]
1529
+ kwargs_semantic[key] = value
1530
+ elif key.startswith("coarse_"):
1531
+ key = key[len("coarse_") :]
1532
+ kwargs_coarse[key] = value
1533
+ elif key.startswith("fine_"):
1534
+ key = key[len("fine_") :]
1535
+ kwargs_fine[key] = value
1536
+ else:
1537
+ # If the key is already in a specific config, then it's been set with a
1538
+ # submodule-specific value and we don't override it
1539
+ if key not in kwargs_semantic:
1540
+ kwargs_semantic[key] = value
1541
+ if key not in kwargs_coarse:
1542
+ kwargs_coarse[key] = value
1543
+ if key not in kwargs_fine:
1544
+ kwargs_fine[key] = value
1545
+
1546
+ # 1. Generate from the semantic model
1547
+ if "generation_config" in kwargs_semantic:
1548
+ kwargs_semantic.pop("generation_config")
1549
+ semantic_output = self.semantic.generate(
1550
+ input_ids,
1551
+ history_prompt=history_prompt,
1552
+ semantic_generation_config=semantic_generation_config,
1553
+ **kwargs_semantic,
1554
+ )
1555
+
1556
+ # 2. Generate from the coarse model
1557
+ if "generation_config" in kwargs_coarse:
1558
+ kwargs_coarse.pop("generation_config")
1559
+ coarse_output = self.coarse_acoustics.generate(
1560
+ semantic_output,
1561
+ history_prompt=history_prompt,
1562
+ semantic_generation_config=semantic_generation_config,
1563
+ coarse_generation_config=coarse_generation_config,
1564
+ codebook_size=self.generation_config.codebook_size,
1565
+ return_output_lengths=return_output_lengths,
1566
+ **kwargs_coarse,
1567
+ )
1568
+
1569
+ output_lengths = None
1570
+ if return_output_lengths:
1571
+ coarse_output, output_lengths = coarse_output
1572
+ # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
1573
+ output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks
1574
+
1575
+ # 3. "generate" from the fine model
1576
+ if "generation_config" in kwargs_fine:
1577
+ kwargs_fine.pop("generation_config")
1578
+ output = self.fine_acoustics.generate(
1579
+ coarse_output,
1580
+ history_prompt=history_prompt,
1581
+ semantic_generation_config=semantic_generation_config,
1582
+ coarse_generation_config=coarse_generation_config,
1583
+ fine_generation_config=fine_generation_config,
1584
+ codebook_size=self.generation_config.codebook_size,
1585
+ **kwargs_fine,
1586
+ )
1587
+
1588
+ if getattr(self, "fine_acoustics_hook", None) is not None:
1589
+ # Manually offload fine_acoustics to CPU
1590
+ # and load codec_model to GPU
1591
+ # since bark doesn't use codec_model forward pass
1592
+ self.fine_acoustics_hook.offload()
1593
+ self.codec_model = self.codec_model.to(self.device)
1594
+
1595
+ # 4. Decode the output and generate audio array
1596
+ audio = self.codec_decode(output, output_lengths)
1597
+
1598
+ if getattr(self, "codec_model_hook", None) is not None:
1599
+ # Offload codec_model to CPU
1600
+ self.codec_model_hook.offload()
1601
+
1602
+ if return_output_lengths:
1603
+ output_lengths = [len(sample) for sample in audio]
1604
+ audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
1605
+ return audio, output_lengths
1606
+
1607
+ return audio
1608
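+ # Sketch of the kwarg routing above: un-prefixed kwargs go to every sub-model,
+ # while prefixed ones win for their own sub-model, e.g.
+ #
+ #   model.generate(**inputs, temperature=0.7, coarse_temperature=0.9)
+ #
+ # runs the semantic and fine stages at temperature 0.7 and the coarse stage at 0.9.
+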
+
1609
+ def tie_weights(self):
1610
+ """
1611
+ Tie the weights between the input embeddings list and the output embeddings list.
1612
+
1613
+ If the `torchscript` flag is set in the configuration, parameter sharing can't be handled, so we clone the
1614
+ weights instead.
1615
+ """
1616
+ for module in self.modules():
1617
+ if hasattr(module, "_tie_weights"):
1618
+ module._tie_weights()
1619
+
1620
+
1621
+ __all__ = [
1622
+ "BarkFineModel",
1623
+ "BarkSemanticModel",
1624
+ "BarkCoarseModel",
1625
+ "BarkModel",
1626
+ "BarkPreTrainedModel",
1627
+ "BarkCausalModel",
1628
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py ADDED
@@ -0,0 +1,340 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Bark
17
+ """
18
+
19
+ import json
20
+ import os
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+
25
+ from ...feature_extraction_utils import BatchFeature
26
+ from ...processing_utils import ProcessorMixin
27
+ from ...tokenization_utils_base import BatchEncoding
28
+ from ...utils import logging
29
+ from ...utils.hub import cached_file
30
+ from ..auto import AutoTokenizer
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class BarkProcessor(ProcessorMixin):
37
+ r"""
38
+ Constructs a Bark processor which wraps a text tokenizer and optional Bark voice presets into a single processor.
39
+
40
+ Args:
41
+ tokenizer ([`PreTrainedTokenizer`]):
42
+ An instance of [`PreTrainedTokenizer`].
43
+ speaker_embeddings (`dict[dict[str]]`, *optional*):
44
+ Optional nested speaker embeddings dictionary. The first level contains voice preset names (e.g
45
+ `"en_speaker_4"`). The second level contains `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`
46
+ embeddings. The values correspond to the path of the corresponding `np.ndarray`. See
47
+ [here](https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c) for
48
+ a list of `voice_preset_names`.
49
+
50
+ """
51
+
52
+ tokenizer_class = "AutoTokenizer"
53
+ attributes = ["tokenizer"]
54
+
55
+ preset_shape = {
56
+ "semantic_prompt": 1, # 1D array of shape (X,)
57
+ "coarse_prompt": 2, # 2D array of shape (2,X)
58
+ "fine_prompt": 2, # 2D array of shape (8,X)
59
+ }
60
+
61
+ def __init__(self, tokenizer, speaker_embeddings=None):
62
+ super().__init__(tokenizer)
63
+
64
+ self.speaker_embeddings = speaker_embeddings
65
+
66
+ @classmethod
67
+ def from_pretrained(
68
+ cls, pretrained_processor_name_or_path, speaker_embeddings_dict_path="speaker_embeddings_path.json", **kwargs
69
+ ):
70
+ r"""
71
+ Instantiate a Bark processor associated with a pretrained model.
72
+
73
+ Args:
74
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
75
+ This can be either:
76
+
77
+ - a string, the *model id* of a pretrained [`BarkProcessor`] hosted inside a model repo on
78
+ huggingface.co.
79
+ - a path to a *directory* containing a processor saved using the [`~BarkProcessor.save_pretrained`]
80
+ method, e.g., `./my_model_directory/`.
81
+ speaker_embeddings_dict_path (`str`, *optional*, defaults to `"speaker_embeddings_path.json"`):
82
+ The name of the `.json` file containing the speaker_embeddings dictionary located in
83
+ `pretrained_model_name_or_path`. If `None`, no speaker_embeddings is loaded.
84
+ **kwargs
85
+ Additional keyword arguments passed along to both
86
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
87
+ """
88
+
89
+ if speaker_embeddings_dict_path is not None:
90
+ speaker_embeddings_path = cached_file(
91
+ pretrained_processor_name_or_path,
92
+ speaker_embeddings_dict_path,
93
+ subfolder=kwargs.pop("subfolder", None),
94
+ cache_dir=kwargs.pop("cache_dir", None),
95
+ force_download=kwargs.pop("force_download", False),
96
+ proxies=kwargs.pop("proxies", None),
97
+ resume_download=kwargs.pop("resume_download", None),
98
+ local_files_only=kwargs.pop("local_files_only", False),
99
+ token=kwargs.pop("use_auth_token", None),
100
+ revision=kwargs.pop("revision", None),
101
+ _raise_exceptions_for_gated_repo=False,
102
+ _raise_exceptions_for_missing_entries=False,
103
+ _raise_exceptions_for_connection_errors=False,
104
+ )
105
+ if speaker_embeddings_path is None:
106
+ logger.warning(
107
+ f"""`{os.path.join(pretrained_processor_name_or_path, speaker_embeddings_dict_path)}` does not exists
108
+ , no preloaded speaker embeddings will be used - Make sure to provide a correct path to the json
109
+ dictionary if wanted, otherwise set `speaker_embeddings_dict_path=None`."""
110
+ )
111
+ speaker_embeddings = None
112
+ else:
113
+ with open(speaker_embeddings_path) as speaker_embeddings_json:
114
+ speaker_embeddings = json.load(speaker_embeddings_json)
115
+ else:
116
+ speaker_embeddings = None
117
+
118
+ if speaker_embeddings is not None:
119
+ if "repo_or_path" in speaker_embeddings:
120
+ speaker_embeddings["repo_or_path"] = pretrained_processor_name_or_path
121
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_processor_name_or_path, **kwargs)
122
+
123
+ return cls(tokenizer=tokenizer, speaker_embeddings=speaker_embeddings)
124
+
125
+ def save_pretrained(
126
+ self,
127
+ save_directory,
128
+ speaker_embeddings_dict_path="speaker_embeddings_path.json",
129
+ speaker_embeddings_directory="speaker_embeddings",
130
+ push_to_hub: bool = False,
131
+ **kwargs,
132
+ ):
133
+ """
134
+ Saves the attributes of this processor (tokenizer...) in the specified directory so that it can be reloaded
135
+ using the [`~BarkProcessor.from_pretrained`] method.
136
+
137
+ Args:
138
+ save_directory (`str` or `os.PathLike`):
139
+ Directory where the tokenizer files and the speaker embeddings will be saved (directory will be created
140
+ if it does not exist).
141
+ speaker_embeddings_dict_path (`str`, *optional*, defaults to `"speaker_embeddings_path.json"`):
142
+ The name of the `.json` file that will contains the speaker_embeddings nested path dictionary, if it
143
+ exists, and that will be located in `pretrained_model_name_or_path/speaker_embeddings_directory`.
144
+ speaker_embeddings_directory (`str`, *optional*, defaults to `"speaker_embeddings/"`):
145
+ The name of the folder in which the speaker_embeddings arrays will be saved.
146
+ push_to_hub (`bool`, *optional*, defaults to `False`):
147
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
148
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
149
+ namespace).
150
+ kwargs:
151
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
152
+ """
153
+ if self.speaker_embeddings is not None:
154
+ os.makedirs(os.path.join(save_directory, speaker_embeddings_directory, "v2"), exist_ok=True)
155
+
156
+ embeddings_dict = {}
157
+
158
+ embeddings_dict["repo_or_path"] = save_directory
159
+
160
+ for prompt_key in self.available_voice_presets:
161
+ voice_preset = self._load_voice_preset(prompt_key)
162
+
163
+ tmp_dict = {}
164
+ for key in self.speaker_embeddings[prompt_key]:
165
+ np.save(
166
+ os.path.join(
167
+ embeddings_dict["repo_or_path"], speaker_embeddings_directory, f"{prompt_key}_{key}"
168
+ ),
169
+ voice_preset[key],
170
+ allow_pickle=False,
171
+ )
172
+ tmp_dict[key] = os.path.join(speaker_embeddings_directory, f"{prompt_key}_{key}.npy")
173
+
174
+ embeddings_dict[prompt_key] = tmp_dict
175
+
176
+ with open(os.path.join(save_directory, speaker_embeddings_dict_path), "w") as fp:
177
+ json.dump(embeddings_dict, fp)
178
+
179
+ super().save_pretrained(save_directory, push_to_hub, **kwargs)
180
+
181
+ def _load_voice_preset(self, voice_preset: Optional[str] = None, **kwargs):
182
+ voice_preset_paths = self.speaker_embeddings[voice_preset]
183
+
184
+ voice_preset_dict = {}
185
+ for key in ["semantic_prompt", "coarse_prompt", "fine_prompt"]:
186
+ if key not in voice_preset_paths:
187
+ raise ValueError(
188
+ f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]."
189
+ )
190
+
191
+ path = cached_file(
192
+ self.speaker_embeddings.get("repo_or_path", "/"),
193
+ voice_preset_paths[key],
194
+ subfolder=kwargs.pop("subfolder", None),
195
+ cache_dir=kwargs.pop("cache_dir", None),
196
+ force_download=kwargs.pop("force_download", False),
197
+ proxies=kwargs.pop("proxies", None),
198
+ resume_download=kwargs.pop("resume_download", None),
199
+ local_files_only=kwargs.pop("local_files_only", False),
200
+ token=kwargs.pop("use_auth_token", None),
201
+ revision=kwargs.pop("revision", None),
202
+ _raise_exceptions_for_gated_repo=False,
203
+ _raise_exceptions_for_missing_entries=False,
204
+ _raise_exceptions_for_connection_errors=False,
205
+ )
206
+ if path is None:
207
+ raise ValueError(
208
+ f"""`{os.path.join(self.speaker_embeddings.get("repo_or_path", "/"), voice_preset_paths[key])}` does not exists
209
+ , no preloaded voice preset will be used - Make sure to provide correct paths to the {voice_preset}
210
+ embeddings."""
211
+ )
212
+
213
+ voice_preset_dict[key] = np.load(path)
214
+
215
+ return voice_preset_dict
216
+
217
+ def _validate_voice_preset_dict(self, voice_preset: Optional[dict] = None):
218
+ for key in ["semantic_prompt", "coarse_prompt", "fine_prompt"]:
219
+ if key not in voice_preset:
220
+ raise ValueError(f"Voice preset unrecognized, missing {key} as a key.")
221
+
222
+ if not isinstance(voice_preset[key], np.ndarray):
223
+ raise TypeError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.")
224
+
225
+ if len(voice_preset[key].shape) != self.preset_shape[key]:
226
+ raise ValueError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.")
227
+
228
+ @property
229
+ def available_voice_presets(self) -> list:
230
+ """
231
+ Returns a list of available voice presets.
232
+
233
+ Returns:
234
+ `list[str]`: A list of voice preset names.
235
+ """
236
+ if self.speaker_embeddings is None:
237
+ return []
238
+
239
+ voice_presets = list(self.speaker_embeddings.keys())
240
+ if "repo_or_path" in voice_presets:
241
+ voice_presets.remove("repo_or_path")
242
+ return voice_presets
243
+
244
+ def _verify_speaker_embeddings(self, remove_unavailable: bool = True):
245
+ # check which actually downloaded properly / are available
246
+ unavailable_keys = []
247
+ if self.speaker_embeddings is not None:
248
+ for voice_preset in self.available_voice_presets:
249
+ try:
250
+ voice_preset_dict = self._load_voice_preset(voice_preset)
251
+ except ValueError:
252
+ # error from `_load_voice_preset` of path not existing
253
+ unavailable_keys.append(voice_preset)
254
+ continue
255
+ self._validate_voice_preset_dict(voice_preset_dict)
256
+
257
+ if unavailable_keys:
258
+ logger.warning(
259
+ f"The following {len(unavailable_keys)} speaker embeddings are not available: {unavailable_keys} "
260
+ "If you would like to use them, please check the paths or try downloading them again."
261
+ )
262
+
263
+ if remove_unavailable:
264
+ for voice_preset in unavailable_keys:
265
+ del self.speaker_embeddings[voice_preset]
266
+
267
+ def __call__(
268
+ self,
269
+ text=None,
270
+ voice_preset=None,
271
+ return_tensors="pt",
272
+ max_length=256,
273
+ add_special_tokens=False,
274
+ return_attention_mask=True,
275
+ return_token_type_ids=False,
276
+ **kwargs,
277
+ ) -> BatchEncoding:
278
+ """
279
+ Main method to prepare for the model one or several sequences(s). This method forwards the `text` and `kwargs`
280
+ arguments to the AutoTokenizer's [`~AutoTokenizer.__call__`] to encode the text. The method also proposes a
281
+ voice preset which is a dictionary of arrays that conditions `Bark`'s output. `kwargs` arguments are forwarded
282
+ to the tokenizer and to `cached_file` method if `voice_preset` is a valid filename.
283
+
284
+ Args:
285
+ text (`str`, `list[str]`, `list[list[str]]`):
286
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
287
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
288
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
289
+ voice_preset (`str`, `dict[np.ndarray]`):
290
+ The voice preset, i.e the speaker embeddings. It can either be a valid voice_preset name, e.g
291
+ `"en_speaker_1"`, or directly a dictionary of `np.ndarray` embeddings for each submodel of `Bark`. Or
292
+ it can be a valid file name of a local `.npz` single voice preset containing the keys
293
+ `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`.
294
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
295
+ If set, will return tensors of a particular framework. Acceptable values are:
296
+
297
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
298
+ - `'np'`: Return NumPy `np.ndarray` objects.
299
+
300
+ Returns:
301
+ [`BatchEncoding`]: A [`BatchEncoding`] object containing the output of the `tokenizer`.
302
+ If a voice preset is provided, the returned object will include a `"history_prompt"` key
303
+ containing a [`BatchFeature`], i.e the voice preset with the right tensors type.
304
+ """
305
+ if voice_preset is not None and not isinstance(voice_preset, dict):
306
+ if (
307
+ isinstance(voice_preset, str)
308
+ and self.speaker_embeddings is not None
309
+ and voice_preset in self.speaker_embeddings
310
+ ):
311
+ voice_preset = self._load_voice_preset(voice_preset)
312
+
313
+ else:
314
+ if isinstance(voice_preset, str) and not voice_preset.endswith(".npz"):
315
+ voice_preset = voice_preset + ".npz"
316
+
317
+ voice_preset = np.load(voice_preset)
318
+
319
+ if voice_preset is not None:
320
+ self._validate_voice_preset_dict(voice_preset, **kwargs)
321
+ voice_preset = BatchFeature(data=voice_preset, tensor_type=return_tensors)
322
+
323
+ encoded_text = self.tokenizer(
324
+ text,
325
+ return_tensors=return_tensors,
326
+ padding="max_length",
327
+ max_length=max_length,
328
+ return_attention_mask=return_attention_mask,
329
+ return_token_type_ids=return_token_type_ids,
330
+ add_special_tokens=add_special_tokens,
331
+ **kwargs,
332
+ )
333
+
334
+ if voice_preset is not None:
335
+ encoded_text["history_prompt"] = voice_preset
336
+
337
+ return encoded_text
338
+
339
+
340
+ __all__ = ["BarkProcessor"]
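A minimal usage sketch of the processor above; the model id and preset name are assumed identifiers for illustration, and any Bark checkpoint bundling speaker embeddings should behave the same way:

```python
from transformers import BarkProcessor

# "suno/bark-small" and "en_speaker_3" are assumed identifiers for illustration.
processor = BarkProcessor.from_pretrained("suno/bark-small")
inputs = processor("Hello, my dog is cute", voice_preset="en_speaker_3")

print(inputs["input_ids"].shape)        # padded to max_length=256 by default
print(inputs["history_prompt"].keys())  # semantic/coarse/fine prompts as tensors
```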
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py ADDED
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_bert import *
+    from .modeling_bert import *
+    from .modeling_flax_bert import *
+    from .modeling_tf_bert import *
+    from .tokenization_bert import *
+    from .tokenization_bert_fast import *
+    from .tokenization_bert_tf import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
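The `_LazyModule` pattern above defers heavy imports until an attribute is first accessed, while the `TYPE_CHECKING` branch keeps static type checkers happy. A minimal standalone sketch of the same idea (names here are illustrative, not the transformers implementation):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Illustrative stand-in: resolve submodule attributes on first access."""

    def __init__(self, name, attr_to_submodule):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule

    def __getattr__(self, attr):
        # The real implementation maps a KeyError here to AttributeError.
        submodule = self._attr_to_submodule[attr]
        module = importlib.import_module(submodule)
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value
```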
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py ADDED
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BERT model configuration"""
+
+from collections import OrderedDict
+from collections.abc import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
+    instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the BERT
+    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
+        is_decoder (`bool`, *optional*, defaults to `False`):
+            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
+
+    Examples:
+
+    ```python
+    >>> from transformers import BertConfig, BertModel
+
+    >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
+    >>> configuration = BertConfig()
+
+    >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
+    >>> model = BertModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "bert"
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        pad_token_id=0,
+        position_embedding_type="absolute",
+        use_cache=True,
+        classifier_dropout=None,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
+        self.classifier_dropout = classifier_dropout
+
+
+class BertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        return OrderedDict(
+            [
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
+            ]
+        )
+
+
+__all__ = ["BertConfig", "BertOnnxConfig"]
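A quick check of what `BertOnnxConfig.inputs` yields for the default task; the `task="default"` argument is an assumed choice here, and the printed mapping follows from the property above:

```python
from transformers import BertConfig
from transformers.models.bert.configuration_bert import BertOnnxConfig

config = BertConfig()
onnx_config = BertOnnxConfig(config, task="default")  # task name is an assumed choice
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'}),
#              ('token_type_ids', {0: 'batch', 1: 'sequence'})])
```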
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py ADDED
@@ -0,0 +1,1801 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BERT model."""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, auto_docstring, logging
+from ...utils.deprecation import deprecate_kwarg
+from .configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
+    """Load tf checkpoints in a pytorch model."""
+    try:
+        import re
+
+        import numpy as np
+        import tensorflow as tf
+    except ImportError:
+        logger.error(
+            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
+            "https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+    # Load weights from TF model
+    init_vars = tf.train.list_variables(tf_path)
+    names = []
+    arrays = []
+    for name, shape in init_vars:
+        logger.info(f"Loading TF weight {name} with shape {shape}")
+        array = tf.train.load_variable(tf_path, name)
+        names.append(name)
+        arrays.append(array)
+
+    for name, array in zip(names, arrays):
+        name = name.split("/")
+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+        # which are not required for using pretrained model
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
+            logger.info(f"Skipping {'/'.join(name)}")
+            continue
+        pointer = model
+        for m_name in name:
+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+                scope_names = re.split(r"_(\d+)", m_name)
+            else:
+                scope_names = [m_name]
+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "output_weights":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "squad":
+                pointer = getattr(pointer, "classifier")
+            else:
+                try:
+                    pointer = getattr(pointer, scope_names[0])
+                except AttributeError:
+                    logger.info(f"Skipping {'/'.join(name)}")
+                    continue
+            if len(scope_names) >= 2:
+                num = int(scope_names[1])
+                pointer = pointer[num]
+        if m_name[-11:] == "_embeddings":
+            pointer = getattr(pointer, "weight")
+        elif m_name == "kernel":
+            array = np.transpose(array)
+        try:
+            if pointer.shape != array.shape:
+                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+        except ValueError as e:
+            e.args += (pointer.shape, array.shape)
+            raise
+        logger.info(f"Initialize PyTorch weight {name}")
+        pointer.data = torch.from_numpy(array)
+    return model
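The checkpoint loader above walks TF variable scopes such as `layer_3/kernel` down to PyTorch attributes. A small, runnable illustration of the scope parsing it relies on:

```python
import re

# Mirrors the loop in load_tf_weights_in_bert: split "layer_3" into ("layer", 3)
# so that `pointer = getattr(model, "layer")[3]` can be resolved.
for m_name in ["layer_3", "kernel", "attention"]:
    if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
        scope_names = re.split(r"_(\d+)", m_name)  # e.g. ['layer', '3', '']
    else:
        scope_names = [m_name]
    print(m_name, "->", scope_names)
```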
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        # Set token_type_ids from the registered all-zeros buffer in the constructor, which usually applies
+        # when it's auto-generated. The registered buffer helps users trace the model without passing
+        # token_type_ids and solves issue #5664
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
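A shape-level sketch of the forward pass above: word, token-type, and (for `"absolute"`) position embeddings are summed elementwise before LayerNorm and dropout. The batch and sequence sizes below are illustrative assumptions.

```python
import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings

config = BertConfig()
embeddings = BertEmbeddings(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # illustrative batch
out = embeddings(input_ids=input_ids)
print(out.shape)  # torch.Size([2, 16, 768])
```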
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+        self.is_decoder = config.is_decoder
+        self.layer_idx = layer_idx
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        batch_size, seq_length, _ = hidden_states.shape
+        query_layer = self.query(hidden_states)
+        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+            1, 2
+        )
+
+        is_updated = False
+        is_cross_attention = encoder_hidden_states is not None
+        if past_key_values is not None:
+            if isinstance(past_key_values, EncoderDecoderCache):
+                is_updated = past_key_values.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_values.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_values.self_attention_cache
+            else:
+                curr_past_key_value = past_key_values
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_values is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_layer = curr_past_key_value.layers[self.layer_idx].keys
+            value_layer = curr_past_key_value.layers[self.layer_idx].values
+        else:
+            key_layer = self.key(current_states)
+            key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
+                1, 2
+            )
+            value_layer = self.value(current_states)
+            value_layer = value_layer.view(
+                batch_size, -1, self.num_attention_heads, self.attention_head_size
+            ).transpose(1, 2)
+
+            if past_key_values is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+                    past_key_values.is_updated[self.layer_idx] = True
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if past_key_values is not None:
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, attention_probs
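To make the `relative_key` branch concrete, here is the distance matrix it embeds; adding `max_position_embeddings - 1` shifts the signed distances into valid embedding indices. The sizes are illustrative.

```python
import torch

query_length, key_length, max_position_embeddings = 4, 4, 512

position_ids_l = torch.arange(query_length).view(-1, 1)
position_ids_r = torch.arange(key_length).view(1, -1)
distance = position_ids_l - position_ids_r
# distance[i, j] = i - j, in [-(key_length - 1), query_length - 1]
indices = distance + max_position_embeddings - 1  # non-negative lookup indices
print(distance)
print(indices.min().item(), indices.max().item())
```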
+
+
+class BertSdpaSelfAttention(BertSelfAttention):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+        super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
+        self.dropout_prob = config.attention_probs_dropout_prob
+
+    # Adapted from BertSelfAttention
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
+            logger.warning_once(
+                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
+                "the manual attention implementation, but specifying the manual implementation will be required from "
+                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
+                '`attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                past_key_values,
+                output_attentions,
+                cache_position,
+            )
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = (
+            self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
+        )
+
+        is_updated = False
+        is_cross_attention = encoder_hidden_states is not None
+        if past_key_values is not None:
+            if isinstance(past_key_values, EncoderDecoderCache):
+                is_updated = past_key_values.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_states from cache
+                    curr_past_key_value = past_key_values.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_values.self_attention_cache
+            else:
+                curr_past_key_value = past_key_values
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        if is_cross_attention and past_key_values is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_layer = curr_past_key_value.layers[self.layer_idx].keys
+            value_layer = curr_past_key_value.layers[self.layer_idx].values
+        else:
+            key_layer = (
+                self.key(current_states)
+                .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
+                .transpose(1, 2)
+            )
+            value_layer = (
+                self.value(current_states)
+                .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
+                .transpose(1, 2)
+            )
+
+            if past_key_values is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
+                    past_key_values.is_updated[self.layer_idx] = True
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        return attn_output, None
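A minimal check of the SDPA dispatch used above: with `attn_mask=None` and `is_causal=True`, `scaled_dot_product_attention` applies the causal mask internally, matching the `is_causal` condition in the code. Shapes below are illustrative.

```python
import torch
import torch.nn.functional as F

bsz, heads, tgt_len, head_dim = 1, 2, 5, 8
q = torch.randn(bsz, heads, tgt_len, head_dim)
k = torch.randn(bsz, heads, tgt_len, head_dim)
v = torch.randn(bsz, heads, tgt_len, head_dim)

# Mirrors: is_causal = is_decoder and not cross-attn and attention_mask is None and tgt_len > 1
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 5, 8])
```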
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+BERT_SELF_ATTENTION_CLASSES = {
+    "eager": BertSelfAttention,
+    "sdpa": BertSdpaSelfAttention,
+}
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, position_embedding_type=None, layer_idx=None):
+        super().__init__()
+        self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config,
+            position_embedding_type=position_embedding_type,
+            layer_idx=layer_idx,
+        )
+        self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(GradientCheckpointingLayer):
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config, layer_idx=layer_idx)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor]:
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            past_key_values=past_key_values,
+            cache_position=cache_position,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+                    " by setting `config.add_cross_attention=True`"
+                )
+
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
+
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
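The feed-forward chunking at the end of `BertLayer` trades compute scheduling for memory by splitting the sequence dimension. A sketch with `apply_chunking_to_forward`; the chunk size and the doubling stand-in for `intermediate -> output` are illustrative choices:

```python
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

hidden = torch.randn(2, 16, 768)  # (batch, seq, hidden), illustrative

def feed_forward_chunk(attention_output):
    return attention_output * 2  # stand-in for intermediate -> output

# chunk_size=4 along seq_len_dim=1: processes 4 tokens at a time, then concatenates.
out = apply_chunking_to_forward(feed_forward_chunk, 4, 1, hidden)
print(torch.equal(out, hidden * 2))  # True
```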
599
+
600
+
601
+ class BertEncoder(nn.Module):
602
+ def __init__(self, config, layer_idx=None):
603
+ super().__init__()
604
+ self.config = config
605
+ self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
606
+ self.gradient_checkpointing = False
607
+
608
+ def forward(
609
+ self,
610
+ hidden_states: torch.Tensor,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ head_mask: Optional[torch.FloatTensor] = None,
613
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
614
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
615
+ past_key_values: Optional[Cache] = None,
616
+ use_cache: Optional[bool] = None,
617
+ output_attentions: Optional[bool] = False,
618
+ output_hidden_states: Optional[bool] = False,
619
+ return_dict: Optional[bool] = True,
620
+ cache_position: Optional[torch.Tensor] = None,
621
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
622
+ all_hidden_states = () if output_hidden_states else None
623
+ all_self_attentions = () if output_attentions else None
624
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
625
+
626
+ if self.gradient_checkpointing and self.training:
627
+ if use_cache:
628
+ logger.warning_once(
629
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
630
+ )
631
+ use_cache = False
632
+
633
+ if use_cache and self.config.is_decoder and past_key_values is None:
634
+ past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
635
+
636
+ if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
637
+ logger.warning_once(
638
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
639
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
640
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
641
+ )
642
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
643
+
644
+ for i, layer_module in enumerate(self.layer):
645
+ if output_hidden_states:
646
+ all_hidden_states = all_hidden_states + (hidden_states,)
647
+
648
+ layer_head_mask = head_mask[i] if head_mask is not None else None
649
+
650
+ layer_outputs = layer_module(
651
+ hidden_states,
652
+ attention_mask,
653
+ layer_head_mask,
654
+ encoder_hidden_states, # as a positional argument for gradient checkpointing
655
+ encoder_attention_mask=encoder_attention_mask,
656
+ past_key_values=past_key_values,
657
+ output_attentions=output_attentions,
658
+ cache_position=cache_position,
659
+ )
660
+
661
+ hidden_states = layer_outputs[0]
662
+ if output_attentions:
663
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
664
+ if self.config.add_cross_attention:
665
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
666
+
667
+ if output_hidden_states:
668
+ all_hidden_states = all_hidden_states + (hidden_states,)
669
+
670
+ if not return_dict:
671
+ return tuple(
672
+ v
673
+ for v in [
674
+ hidden_states,
675
+ past_key_values,
676
+ all_hidden_states,
677
+ all_self_attentions,
678
+ all_cross_attentions,
679
+ ]
680
+ if v is not None
681
+ )
682
+ return BaseModelOutputWithPastAndCrossAttentions(
683
+ last_hidden_state=hidden_states,
684
+ past_key_values=past_key_values,
685
+ hidden_states=all_hidden_states,
686
+ attentions=all_self_attentions,
687
+ cross_attentions=all_cross_attentions,
688
+ )
689
+
690
+
691
+ class BertPooler(nn.Module):
692
+ def __init__(self, config):
693
+ super().__init__()
694
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
695
+ self.activation = nn.Tanh()
696
+
697
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
698
+ # We "pool" the model by simply taking the hidden state corresponding
699
+ # to the first token.
700
+ first_token_tensor = hidden_states[:, 0]
701
+ pooled_output = self.dense(first_token_tensor)
702
+ pooled_output = self.activation(pooled_output)
703
+ return pooled_output
704
+
705
+
706
+ class BertPredictionHeadTransform(nn.Module):
707
+ def __init__(self, config):
708
+ super().__init__()
709
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
710
+ if isinstance(config.hidden_act, str):
711
+ self.transform_act_fn = ACT2FN[config.hidden_act]
712
+ else:
713
+ self.transform_act_fn = config.hidden_act
714
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
715
+
716
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
717
+ hidden_states = self.dense(hidden_states)
718
+ hidden_states = self.transform_act_fn(hidden_states)
719
+ hidden_states = self.LayerNorm(hidden_states)
720
+ return hidden_states
721
+
722
+
723
+ class BertLMPredictionHead(nn.Module):
724
+ def __init__(self, config):
725
+ super().__init__()
726
+ self.transform = BertPredictionHeadTransform(config)
727
+
728
+ # The output weights are the same as the input embeddings, but there is
729
+ # an output-only bias for each token.
730
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
731
+
732
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
733
+
734
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
735
+ self.decoder.bias = self.bias
736
+
737
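+ # Re-invoked when weights are tied (e.g. after `resize_token_embeddings`) so `decoder.bias` keeps pointing at this module's bias parameter.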
+ def _tie_weights(self):
738
+ self.decoder.bias = self.bias
739
+
740
+ def forward(self, hidden_states):
741
+ hidden_states = self.transform(hidden_states)
742
+ hidden_states = self.decoder(hidden_states)
743
+ return hidden_states
744
+
745
+
746
+ class BertOnlyMLMHead(nn.Module):
747
+ def __init__(self, config):
748
+ super().__init__()
749
+ self.predictions = BertLMPredictionHead(config)
750
+
751
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
752
+ prediction_scores = self.predictions(sequence_output)
753
+ return prediction_scores
754
+
755
+
756
+ class BertOnlyNSPHead(nn.Module):
757
+ def __init__(self, config):
758
+ super().__init__()
759
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
760
+
761
+ def forward(self, pooled_output):
762
+ seq_relationship_score = self.seq_relationship(pooled_output)
763
+ return seq_relationship_score
764
+
765
+
766
+ class BertPreTrainingHeads(nn.Module):
767
+ def __init__(self, config):
768
+ super().__init__()
769
+ self.predictions = BertLMPredictionHead(config)
770
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
771
+
772
+ def forward(self, sequence_output, pooled_output):
773
+ prediction_scores = self.predictions(sequence_output)
774
+ seq_relationship_score = self.seq_relationship(pooled_output)
775
+ return prediction_scores, seq_relationship_score
776
+
777
+
778
+ @auto_docstring
779
+ class BertPreTrainedModel(PreTrainedModel):
780
+ config: BertConfig
781
+ load_tf_weights = load_tf_weights_in_bert
782
+ base_model_prefix = "bert"
783
+ supports_gradient_checkpointing = True
784
+ _supports_sdpa = True
785
+
786
+ def _init_weights(self, module):
787
+ """Initialize the weights"""
788
+ if isinstance(module, nn.Linear):
789
+ # Slightly different from the TF version which uses truncated_normal for initialization
790
+ # cf https://github.com/pytorch/pytorch/pull/5617
791
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
792
+ if module.bias is not None:
793
+ module.bias.data.zero_()
794
+ elif isinstance(module, nn.Embedding):
795
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
796
+ if module.padding_idx is not None:
797
+ module.weight.data[module.padding_idx].zero_()
798
+ elif isinstance(module, nn.LayerNorm):
799
+ module.bias.data.zero_()
800
+ module.weight.data.fill_(1.0)
801
+ elif isinstance(module, BertLMPredictionHead):
802
+ module.bias.data.zero_()
803
+
804
+
805
+ @dataclass
806
+ @auto_docstring(
807
+ custom_intro="""
808
+ Output type of [`BertForPreTraining`].
809
+ """
810
+ )
811
+ class BertForPreTrainingOutput(ModelOutput):
812
+ r"""
813
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
814
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
815
+ (classification) loss.
816
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
817
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
818
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
819
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
820
+ before SoftMax).
821
+ """
822
+
823
+ loss: Optional[torch.FloatTensor] = None
824
+ prediction_logits: Optional[torch.FloatTensor] = None
825
+ seq_relationship_logits: Optional[torch.FloatTensor] = None
826
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
827
+ attentions: Optional[tuple[torch.FloatTensor]] = None
828
+
829
+
830
+ @auto_docstring(
831
+ custom_intro="""
832
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
833
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
834
+ all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
835
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
836
+
837
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
838
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` and
839
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
840
+ """
841
+ )
842
+ class BertModel(BertPreTrainedModel):
843
+ _no_split_modules = ["BertEmbeddings", "BertLayer"]
844
+
845
+ def __init__(self, config, add_pooling_layer=True):
846
+ r"""
847
+ add_pooling_layer (bool, *optional*, defaults to `True`):
848
+ Whether to add a pooling layer
849
+ """
850
+ super().__init__(config)
851
+ self.config = config
852
+
853
+ self.embeddings = BertEmbeddings(config)
854
+ self.encoder = BertEncoder(config)
855
+
856
+ self.pooler = BertPooler(config) if add_pooling_layer else None
857
+
858
+ self.attn_implementation = config._attn_implementation
859
+ self.position_embedding_type = config.position_embedding_type
860
+
861
+ # Initialize weights and apply final processing
862
+ self.post_init()
863
+
864
+ def get_input_embeddings(self):
865
+ return self.embeddings.word_embeddings
866
+
867
+ def set_input_embeddings(self, value):
868
+ self.embeddings.word_embeddings = value
869
+
870
+ def _prune_heads(self, heads_to_prune):
871
+ """
872
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
873
+ class `PreTrainedModel`.
874
+ """
875
+ for layer, heads in heads_to_prune.items():
876
+ self.encoder.layer[layer].attention.prune_heads(heads)
877
+
878
+ @auto_docstring
879
+ def forward(
880
+ self,
881
+ input_ids: Optional[torch.Tensor] = None,
882
+ attention_mask: Optional[torch.Tensor] = None,
883
+ token_type_ids: Optional[torch.Tensor] = None,
884
+ position_ids: Optional[torch.Tensor] = None,
885
+ head_mask: Optional[torch.Tensor] = None,
886
+ inputs_embeds: Optional[torch.Tensor] = None,
887
+ encoder_hidden_states: Optional[torch.Tensor] = None,
888
+ encoder_attention_mask: Optional[torch.Tensor] = None,
889
+ past_key_values: Optional[Cache] = None,
890
+ use_cache: Optional[bool] = None,
891
+ output_attentions: Optional[bool] = None,
892
+ output_hidden_states: Optional[bool] = None,
893
+ return_dict: Optional[bool] = None,
894
+ cache_position: Optional[torch.Tensor] = None,
895
+ ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
896
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
897
+ output_hidden_states = (
898
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
899
+ )
900
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
901
+
902
+ if self.config.is_decoder:
903
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
904
+ else:
905
+ use_cache = False
906
+
907
+ if input_ids is not None and inputs_embeds is not None:
908
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
909
+ elif input_ids is not None:
910
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
911
+ input_shape = input_ids.size()
912
+ elif inputs_embeds is not None:
913
+ input_shape = inputs_embeds.size()[:-1]
914
+ else:
915
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
916
+
917
+ batch_size, seq_length = input_shape
918
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
919
+
920
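+ # Length of the already-cached prefix, used to offset position ids and masks during incremental decoding; handles both the legacy tuple format and the `Cache` API.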
+ past_key_values_length = 0
921
+ if past_key_values is not None:
922
+ past_key_values_length = (
923
+ past_key_values[0][0].shape[-2]
924
+ if not isinstance(past_key_values, Cache)
925
+ else past_key_values.get_seq_length()
926
+ )
927
+
928
+ if token_type_ids is None:
929
+ if hasattr(self.embeddings, "token_type_ids"):
930
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
931
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
932
+ token_type_ids = buffered_token_type_ids_expanded
933
+ else:
934
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
935
+
936
+ embedding_output = self.embeddings(
937
+ input_ids=input_ids,
938
+ position_ids=position_ids,
939
+ token_type_ids=token_type_ids,
940
+ inputs_embeds=inputs_embeds,
941
+ past_key_values_length=past_key_values_length,
942
+ )
943
+
944
+ if attention_mask is None:
945
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
946
+
947
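+ # The SDPA fast path builds 4D masks directly; it is only taken with absolute position embeddings, no head mask, and when attention weights are not requested (SDPA does not return attention probabilities).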
+ use_sdpa_attention_masks = (
948
+ self.attn_implementation == "sdpa"
949
+ and self.position_embedding_type == "absolute"
950
+ and head_mask is None
951
+ and not output_attentions
952
+ )
953
+
954
+ # Expand the attention mask
955
+ if use_sdpa_attention_masks and attention_mask.dim() == 2:
956
+ # Expand the attention mask for SDPA.
957
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
958
+ if self.config.is_decoder:
959
+ extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
960
+ attention_mask,
961
+ input_shape,
962
+ embedding_output,
963
+ past_key_values_length,
964
+ )
965
+ else:
966
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
967
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
968
+ )
969
+ else:
970
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
971
+ # ourselves in which case we just need to make it broadcastable to all heads.
972
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
973
+
974
+ # If a 2D or 3D attention mask is provided for the cross-attention
975
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
976
+ if self.config.is_decoder and encoder_hidden_states is not None:
977
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
978
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
979
+ if encoder_attention_mask is None:
980
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
981
+
982
+ if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
983
+ # Expand the attention mask for SDPA.
984
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
985
+ encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
986
+ encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
987
+ )
988
+ else:
989
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
990
+ else:
991
+ encoder_extended_attention_mask = None
992
+
993
+ # Prepare head mask if needed
994
+ # 1.0 in head_mask indicate we keep the head
995
+ # attention_probs has shape bsz x n_heads x N x N
996
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
997
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
998
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
999
+
1000
+ encoder_outputs = self.encoder(
1001
+ embedding_output,
1002
+ attention_mask=extended_attention_mask,
1003
+ head_mask=head_mask,
1004
+ encoder_hidden_states=encoder_hidden_states,
1005
+ encoder_attention_mask=encoder_extended_attention_mask,
1006
+ past_key_values=past_key_values,
1007
+ use_cache=use_cache,
1008
+ output_attentions=output_attentions,
1009
+ output_hidden_states=output_hidden_states,
1010
+ return_dict=return_dict,
1011
+ cache_position=cache_position,
1012
+ )
1013
+ sequence_output = encoder_outputs[0]
1014
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1015
+
1016
+ if not return_dict:
1017
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1018
+
1019
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1020
+ last_hidden_state=sequence_output,
1021
+ pooler_output=pooled_output,
1022
+ past_key_values=encoder_outputs.past_key_values,
1023
+ hidden_states=encoder_outputs.hidden_states,
1024
+ attentions=encoder_outputs.attentions,
1025
+ cross_attentions=encoder_outputs.cross_attentions,
1026
+ )
1027
+
1028
+
1029
+ @auto_docstring(
1030
+ custom_intro="""
1031
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1032
+ sentence prediction (classification)` head.
1033
+ """
1034
+ )
1035
+ class BertForPreTraining(BertPreTrainedModel):
1036
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1037
+
1038
+ def __init__(self, config):
1039
+ super().__init__(config)
1040
+
1041
+ self.bert = BertModel(config)
1042
+ self.cls = BertPreTrainingHeads(config)
1043
+
1044
+ # Initialize weights and apply final processing
1045
+ self.post_init()
1046
+
1047
+ def get_output_embeddings(self):
1048
+ return self.cls.predictions.decoder
1049
+
1050
+ def set_output_embeddings(self, new_embeddings):
1051
+ self.cls.predictions.decoder = new_embeddings
1052
+ self.cls.predictions.bias = new_embeddings.bias
1053
+
1054
+ @auto_docstring
1055
+ def forward(
1056
+ self,
1057
+ input_ids: Optional[torch.Tensor] = None,
1058
+ attention_mask: Optional[torch.Tensor] = None,
1059
+ token_type_ids: Optional[torch.Tensor] = None,
1060
+ position_ids: Optional[torch.Tensor] = None,
1061
+ head_mask: Optional[torch.Tensor] = None,
1062
+ inputs_embeds: Optional[torch.Tensor] = None,
1063
+ labels: Optional[torch.Tensor] = None,
1064
+ next_sentence_label: Optional[torch.Tensor] = None,
1065
+ output_attentions: Optional[bool] = None,
1066
+ output_hidden_states: Optional[bool] = None,
1067
+ return_dict: Optional[bool] = None,
1068
+ ) -> Union[tuple[torch.Tensor], BertForPreTrainingOutput]:
1069
+ r"""
1070
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1071
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1072
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
1073
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1074
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1075
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
1076
+ pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
1077
+
1078
+ - 0 indicates sequence B is a continuation of sequence A,
1079
+ - 1 indicates sequence B is a random sequence.
1080
+
1081
+ Example:
1082
+
1083
+ ```python
1084
+ >>> from transformers import AutoTokenizer, BertForPreTraining
1085
+ >>> import torch
1086
+
1087
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1088
+ >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
1089
+
1090
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1091
+ >>> outputs = model(**inputs)
1092
+
1093
+ >>> prediction_logits = outputs.prediction_logits
1094
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1095
+ ```
1096
+ """
1097
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1098
+
1099
+ outputs = self.bert(
1100
+ input_ids,
1101
+ attention_mask=attention_mask,
1102
+ token_type_ids=token_type_ids,
1103
+ position_ids=position_ids,
1104
+ head_mask=head_mask,
1105
+ inputs_embeds=inputs_embeds,
1106
+ output_attentions=output_attentions,
1107
+ output_hidden_states=output_hidden_states,
1108
+ return_dict=return_dict,
1109
+ )
1110
+
1111
+ sequence_output, pooled_output = outputs[:2]
1112
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1113
+
1114
+ total_loss = None
1115
+ if labels is not None and next_sentence_label is not None:
1116
+ loss_fct = CrossEntropyLoss()
1117
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1118
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1119
+ total_loss = masked_lm_loss + next_sentence_loss
1120
+
1121
+ if not return_dict:
1122
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1123
+ return ((total_loss,) + output) if total_loss is not None else output
1124
+
1125
+ return BertForPreTrainingOutput(
1126
+ loss=total_loss,
1127
+ prediction_logits=prediction_scores,
1128
+ seq_relationship_logits=seq_relationship_score,
1129
+ hidden_states=outputs.hidden_states,
1130
+ attentions=outputs.attentions,
1131
+ )
1132
+
1133
+
1134
+ @auto_docstring(
1135
+ custom_intro="""
1136
+ Bert Model with a `language modeling` head on top for CLM fine-tuning.
1137
+ """
1138
+ )
1139
+ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
1140
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1141
+
1142
+ def __init__(self, config):
1143
+ super().__init__(config)
1144
+
1145
+ if not config.is_decoder:
1146
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1147
+
1148
+ self.bert = BertModel(config, add_pooling_layer=False)
1149
+ self.cls = BertOnlyMLMHead(config)
1150
+
1151
+ # Initialize weights and apply final processing
1152
+ self.post_init()
1153
+
1154
+ def get_output_embeddings(self):
1155
+ return self.cls.predictions.decoder
1156
+
1157
+ def set_output_embeddings(self, new_embeddings):
1158
+ self.cls.predictions.decoder = new_embeddings
1159
+ self.cls.predictions.bias = new_embeddings.bias
1160
+
1161
+ @auto_docstring
1162
+ def forward(
1163
+ self,
1164
+ input_ids: Optional[torch.Tensor] = None,
1165
+ attention_mask: Optional[torch.Tensor] = None,
1166
+ token_type_ids: Optional[torch.Tensor] = None,
1167
+ position_ids: Optional[torch.Tensor] = None,
1168
+ head_mask: Optional[torch.Tensor] = None,
1169
+ inputs_embeds: Optional[torch.Tensor] = None,
1170
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1171
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1172
+ labels: Optional[torch.Tensor] = None,
1173
+ past_key_values: Optional[Cache] = None,
1174
+ use_cache: Optional[bool] = None,
1175
+ output_attentions: Optional[bool] = None,
1176
+ output_hidden_states: Optional[bool] = None,
1177
+ return_dict: Optional[bool] = None,
1178
+ cache_position: Optional[torch.Tensor] = None,
1179
+ **loss_kwargs,
1180
+ ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1181
+ r"""
1182
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1183
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1184
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1185
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1186
+ """
1187
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1188
+ if labels is not None:
1189
+ use_cache = False
1190
+
1191
+ outputs = self.bert(
1192
+ input_ids,
1193
+ attention_mask=attention_mask,
1194
+ token_type_ids=token_type_ids,
1195
+ position_ids=position_ids,
1196
+ head_mask=head_mask,
1197
+ inputs_embeds=inputs_embeds,
1198
+ encoder_hidden_states=encoder_hidden_states,
1199
+ encoder_attention_mask=encoder_attention_mask,
1200
+ past_key_values=past_key_values,
1201
+ use_cache=use_cache,
1202
+ output_attentions=output_attentions,
1203
+ output_hidden_states=output_hidden_states,
1204
+ return_dict=return_dict,
1205
+ cache_position=cache_position,
1206
+ )
1207
+
1208
+ sequence_output = outputs[0]
1209
+ prediction_scores = self.cls(sequence_output)
1210
+
1211
+ lm_loss = None
1212
+ if labels is not None:
1213
+ lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
1214
+
1215
+ if not return_dict:
1216
+ output = (prediction_scores,) + outputs[2:]
1217
+ return ((lm_loss,) + output) if lm_loss is not None else output
1218
+
1219
+ return CausalLMOutputWithCrossAttentions(
1220
+ loss=lm_loss,
1221
+ logits=prediction_scores,
1222
+ past_key_values=outputs.past_key_values,
1223
+ hidden_states=outputs.hidden_states,
1224
+ attentions=outputs.attentions,
1225
+ cross_attentions=outputs.cross_attentions,
1226
+ )
1227
+
1228
+
1229
+ @auto_docstring
1230
+ class BertForMaskedLM(BertPreTrainedModel):
1231
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1232
+
1233
+ def __init__(self, config):
1234
+ super().__init__(config)
1235
+
1236
+ if config.is_decoder:
1237
+ logger.warning(
1238
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1239
+ "bi-directional self-attention."
1240
+ )
1241
+
1242
+ self.bert = BertModel(config, add_pooling_layer=False)
1243
+ self.cls = BertOnlyMLMHead(config)
1244
+
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ def get_output_embeddings(self):
1249
+ return self.cls.predictions.decoder
1250
+
1251
+ def set_output_embeddings(self, new_embeddings):
1252
+ self.cls.predictions.decoder = new_embeddings
1253
+ self.cls.predictions.bias = new_embeddings.bias
1254
+
1255
+ @auto_docstring
1256
+ def forward(
1257
+ self,
1258
+ input_ids: Optional[torch.Tensor] = None,
1259
+ attention_mask: Optional[torch.Tensor] = None,
1260
+ token_type_ids: Optional[torch.Tensor] = None,
1261
+ position_ids: Optional[torch.Tensor] = None,
1262
+ head_mask: Optional[torch.Tensor] = None,
1263
+ inputs_embeds: Optional[torch.Tensor] = None,
1264
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1265
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1266
+ labels: Optional[torch.Tensor] = None,
1267
+ output_attentions: Optional[bool] = None,
1268
+ output_hidden_states: Optional[bool] = None,
1269
+ return_dict: Optional[bool] = None,
1270
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
1271
+ r"""
1272
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1273
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1274
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1275
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1276
+ """
1277
+
1278
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1279
+
1280
+ outputs = self.bert(
1281
+ input_ids,
1282
+ attention_mask=attention_mask,
1283
+ token_type_ids=token_type_ids,
1284
+ position_ids=position_ids,
1285
+ head_mask=head_mask,
1286
+ inputs_embeds=inputs_embeds,
1287
+ encoder_hidden_states=encoder_hidden_states,
1288
+ encoder_attention_mask=encoder_attention_mask,
1289
+ output_attentions=output_attentions,
1290
+ output_hidden_states=output_hidden_states,
1291
+ return_dict=return_dict,
1292
+ )
1293
+
1294
+ sequence_output = outputs[0]
1295
+ prediction_scores = self.cls(sequence_output)
1296
+
1297
+ masked_lm_loss = None
1298
+ if labels is not None:
1299
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1300
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1301
+
1302
+ if not return_dict:
1303
+ output = (prediction_scores,) + outputs[2:]
1304
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1305
+
1306
+ return MaskedLMOutput(
1307
+ loss=masked_lm_loss,
1308
+ logits=prediction_scores,
1309
+ hidden_states=outputs.hidden_states,
1310
+ attentions=outputs.attentions,
1311
+ )
1312
+
1313
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1314
+ input_shape = input_ids.shape
1315
+ effective_batch_size = input_shape[0]
1316
+
1317
+ # append a dummy PAD token (with a matching zeroed attention-mask entry) at the end of the sequence
1318
+ if self.config.pad_token_id is None:
1319
+ raise ValueError("The PAD token should be defined for generation")
1320
+
1321
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1322
+ dummy_token = torch.full(
1323
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1324
+ )
1325
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1326
+
1327
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1328
+
1329
+ @classmethod
1330
+ def can_generate(cls) -> bool:
1331
+ """
1332
+ Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
1333
+ `prepare_inputs_for_generation` method.
1334
+ """
1335
+ return False
1336
+
1337
+
1338
+ @auto_docstring(
1339
+ custom_intro="""
1340
+ Bert Model with a `next sentence prediction (classification)` head on top.
1341
+ """
1342
+ )
1343
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1344
+ def __init__(self, config):
1345
+ super().__init__(config)
1346
+
1347
+ self.bert = BertModel(config)
1348
+ self.cls = BertOnlyNSPHead(config)
1349
+
1350
+ # Initialize weights and apply final processing
1351
+ self.post_init()
1352
+
1353
+ @auto_docstring
1354
+ def forward(
1355
+ self,
1356
+ input_ids: Optional[torch.Tensor] = None,
1357
+ attention_mask: Optional[torch.Tensor] = None,
1358
+ token_type_ids: Optional[torch.Tensor] = None,
1359
+ position_ids: Optional[torch.Tensor] = None,
1360
+ head_mask: Optional[torch.Tensor] = None,
1361
+ inputs_embeds: Optional[torch.Tensor] = None,
1362
+ labels: Optional[torch.Tensor] = None,
1363
+ output_attentions: Optional[bool] = None,
1364
+ output_hidden_states: Optional[bool] = None,
1365
+ return_dict: Optional[bool] = None,
1366
+ **kwargs,
1367
+ ) -> Union[tuple[torch.Tensor], NextSentencePredictorOutput]:
1368
+ r"""
1369
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1370
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1371
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1372
+
1373
+ - 0 indicates sequence B is a continuation of sequence A,
1374
+ - 1 indicates sequence B is a random sequence.
1375
+
1376
+ Example:
1377
+
1378
+ ```python
1379
+ >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
1380
+ >>> import torch
1381
+
1382
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1383
+ >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
1384
+
1385
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1386
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1387
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1388
+
1389
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1390
+ >>> logits = outputs.logits
1391
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1392
+ ```
1393
+ """
1394
+
1395
+ if "next_sentence_label" in kwargs:
1396
+ warnings.warn(
1397
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1398
+ " `labels` instead.",
1399
+ FutureWarning,
1400
+ )
1401
+ labels = kwargs.pop("next_sentence_label")
1402
+
1403
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1404
+
1405
+ outputs = self.bert(
1406
+ input_ids,
1407
+ attention_mask=attention_mask,
1408
+ token_type_ids=token_type_ids,
1409
+ position_ids=position_ids,
1410
+ head_mask=head_mask,
1411
+ inputs_embeds=inputs_embeds,
1412
+ output_attentions=output_attentions,
1413
+ output_hidden_states=output_hidden_states,
1414
+ return_dict=return_dict,
1415
+ )
1416
+
1417
+ pooled_output = outputs[1]
1418
+
1419
+ seq_relationship_scores = self.cls(pooled_output)
1420
+
1421
+ next_sentence_loss = None
1422
+ if labels is not None:
1423
+ loss_fct = CrossEntropyLoss()
1424
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1425
+
1426
+ if not return_dict:
1427
+ output = (seq_relationship_scores,) + outputs[2:]
1428
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1429
+
1430
+ return NextSentencePredictorOutput(
1431
+ loss=next_sentence_loss,
1432
+ logits=seq_relationship_scores,
1433
+ hidden_states=outputs.hidden_states,
1434
+ attentions=outputs.attentions,
1435
+ )
1436
+
1437
+
1438
+ @auto_docstring(
1439
+ custom_intro="""
1440
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1441
+ output) e.g. for GLUE tasks.
1442
+ """
1443
+ )
1444
+ class BertForSequenceClassification(BertPreTrainedModel):
1445
+ def __init__(self, config):
1446
+ super().__init__(config)
1447
+ self.num_labels = config.num_labels
1448
+ self.config = config
1449
+
1450
+ self.bert = BertModel(config)
1451
+ classifier_dropout = (
1452
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1453
+ )
1454
+ self.dropout = nn.Dropout(classifier_dropout)
1455
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1456
+
1457
+ # Initialize weights and apply final processing
1458
+ self.post_init()
1459
+
1460
+ @auto_docstring
1461
+ def forward(
1462
+ self,
1463
+ input_ids: Optional[torch.Tensor] = None,
1464
+ attention_mask: Optional[torch.Tensor] = None,
1465
+ token_type_ids: Optional[torch.Tensor] = None,
1466
+ position_ids: Optional[torch.Tensor] = None,
1467
+ head_mask: Optional[torch.Tensor] = None,
1468
+ inputs_embeds: Optional[torch.Tensor] = None,
1469
+ labels: Optional[torch.Tensor] = None,
1470
+ output_attentions: Optional[bool] = None,
1471
+ output_hidden_states: Optional[bool] = None,
1472
+ return_dict: Optional[bool] = None,
1473
+ ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
1474
+ r"""
1475
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1476
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1477
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1478
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1479
+ """
1480
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1481
+
1482
+ outputs = self.bert(
1483
+ input_ids,
1484
+ attention_mask=attention_mask,
1485
+ token_type_ids=token_type_ids,
1486
+ position_ids=position_ids,
1487
+ head_mask=head_mask,
1488
+ inputs_embeds=inputs_embeds,
1489
+ output_attentions=output_attentions,
1490
+ output_hidden_states=output_hidden_states,
1491
+ return_dict=return_dict,
1492
+ )
1493
+
1494
+ pooled_output = outputs[1]
1495
+
1496
+ pooled_output = self.dropout(pooled_output)
1497
+ logits = self.classifier(pooled_output)
1498
+
1499
+ loss = None
1500
+ if labels is not None:
1501
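+ # Infer the problem type from `num_labels` and the labels' dtype: regression (MSE), single-label classification (cross-entropy), or multi-label classification (BCE-with-logits).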
+ if self.config.problem_type is None:
1502
+ if self.num_labels == 1:
1503
+ self.config.problem_type = "regression"
1504
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1505
+ self.config.problem_type = "single_label_classification"
1506
+ else:
1507
+ self.config.problem_type = "multi_label_classification"
1508
+
1509
+ if self.config.problem_type == "regression":
1510
+ loss_fct = MSELoss()
1511
+ if self.num_labels == 1:
1512
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1513
+ else:
1514
+ loss = loss_fct(logits, labels)
1515
+ elif self.config.problem_type == "single_label_classification":
1516
+ loss_fct = CrossEntropyLoss()
1517
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1518
+ elif self.config.problem_type == "multi_label_classification":
1519
+ loss_fct = BCEWithLogitsLoss()
1520
+ loss = loss_fct(logits, labels)
1521
+ if not return_dict:
1522
+ output = (logits,) + outputs[2:]
1523
+ return ((loss,) + output) if loss is not None else output
1524
+
1525
+ return SequenceClassifierOutput(
1526
+ loss=loss,
1527
+ logits=logits,
1528
+ hidden_states=outputs.hidden_states,
1529
+ attentions=outputs.attentions,
1530
+ )
1531
+
1532
+
1533
+ @auto_docstring
1534
+ class BertForMultipleChoice(BertPreTrainedModel):
1535
+ def __init__(self, config):
1536
+ super().__init__(config)
1537
+
1538
+ self.bert = BertModel(config)
1539
+ classifier_dropout = (
1540
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1541
+ )
1542
+ self.dropout = nn.Dropout(classifier_dropout)
1543
+ self.classifier = nn.Linear(config.hidden_size, 1)
1544
+
1545
+ # Initialize weights and apply final processing
1546
+ self.post_init()
1547
+
1548
+ @auto_docstring
1549
+ def forward(
1550
+ self,
1551
+ input_ids: Optional[torch.Tensor] = None,
1552
+ attention_mask: Optional[torch.Tensor] = None,
1553
+ token_type_ids: Optional[torch.Tensor] = None,
1554
+ position_ids: Optional[torch.Tensor] = None,
1555
+ head_mask: Optional[torch.Tensor] = None,
1556
+ inputs_embeds: Optional[torch.Tensor] = None,
1557
+ labels: Optional[torch.Tensor] = None,
1558
+ output_attentions: Optional[bool] = None,
1559
+ output_hidden_states: Optional[bool] = None,
1560
+ return_dict: Optional[bool] = None,
1561
+ ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
1562
+ r"""
1563
+ input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
1564
+ Indices of input sequence tokens in the vocabulary.
1565
+
1566
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1567
+ [`PreTrainedTokenizer.__call__`] for details.
1568
+
1569
+ [What are input IDs?](../glossary#input-ids)
1570
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
1571
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1572
+ 1]`:
1573
+
1574
+ - 0 corresponds to a *sentence A* token,
1575
+ - 1 corresponds to a *sentence B* token.
1576
+
1577
+ [What are token type IDs?](../glossary#token-type-ids)
1578
+ position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
1579
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1580
+ config.max_position_embeddings - 1]`.
1581
+
1582
+ [What are position IDs?](../glossary#position-ids)
1583
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
1584
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1585
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1586
+ model's internal embedding lookup matrix.
1587
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1588
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1589
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1590
+ `input_ids` above)
1591
+ """
1592
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1593
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1594
+
1595
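+ # Flatten the choices dimension: (batch_size, num_choices, seq_len) -> (batch_size * num_choices, seq_len) so each choice is encoded independently.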
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1596
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1597
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1598
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1599
+ inputs_embeds = (
1600
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1601
+ if inputs_embeds is not None
1602
+ else None
1603
+ )
1604
+
1605
+ outputs = self.bert(
1606
+ input_ids,
1607
+ attention_mask=attention_mask,
1608
+ token_type_ids=token_type_ids,
1609
+ position_ids=position_ids,
1610
+ head_mask=head_mask,
1611
+ inputs_embeds=inputs_embeds,
1612
+ output_attentions=output_attentions,
1613
+ output_hidden_states=output_hidden_states,
1614
+ return_dict=return_dict,
1615
+ )
1616
+
1617
+ pooled_output = outputs[1]
1618
+
1619
+ pooled_output = self.dropout(pooled_output)
1620
+ logits = self.classifier(pooled_output)
1621
+ reshaped_logits = logits.view(-1, num_choices)
1622
+
1623
+ loss = None
1624
+ if labels is not None:
1625
+ loss_fct = CrossEntropyLoss()
1626
+ loss = loss_fct(reshaped_logits, labels)
1627
+
1628
+ if not return_dict:
1629
+ output = (reshaped_logits,) + outputs[2:]
1630
+ return ((loss,) + output) if loss is not None else output
1631
+
1632
+ return MultipleChoiceModelOutput(
1633
+ loss=loss,
1634
+ logits=reshaped_logits,
1635
+ hidden_states=outputs.hidden_states,
1636
+ attentions=outputs.attentions,
1637
+ )
1638
+
1639
+
1640
+ @auto_docstring
1641
+ class BertForTokenClassification(BertPreTrainedModel):
1642
+ def __init__(self, config):
1643
+ super().__init__(config)
1644
+ self.num_labels = config.num_labels
1645
+
1646
+ self.bert = BertModel(config, add_pooling_layer=False)
1647
+ classifier_dropout = (
1648
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1649
+ )
1650
+ self.dropout = nn.Dropout(classifier_dropout)
1651
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1652
+
1653
+ # Initialize weights and apply final processing
1654
+ self.post_init()
1655
+
1656
+ @auto_docstring
1657
+ def forward(
1658
+ self,
1659
+ input_ids: Optional[torch.Tensor] = None,
1660
+ attention_mask: Optional[torch.Tensor] = None,
1661
+ token_type_ids: Optional[torch.Tensor] = None,
1662
+ position_ids: Optional[torch.Tensor] = None,
1663
+ head_mask: Optional[torch.Tensor] = None,
1664
+ inputs_embeds: Optional[torch.Tensor] = None,
1665
+ labels: Optional[torch.Tensor] = None,
1666
+ output_attentions: Optional[bool] = None,
1667
+ output_hidden_states: Optional[bool] = None,
1668
+ return_dict: Optional[bool] = None,
1669
+ ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
1670
+ r"""
1671
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1672
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1673
+ """
1674
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1675
+
1676
+ outputs = self.bert(
1677
+ input_ids,
1678
+ attention_mask=attention_mask,
1679
+ token_type_ids=token_type_ids,
1680
+ position_ids=position_ids,
1681
+ head_mask=head_mask,
1682
+ inputs_embeds=inputs_embeds,
1683
+ output_attentions=output_attentions,
1684
+ output_hidden_states=output_hidden_states,
1685
+ return_dict=return_dict,
1686
+ )
1687
+
1688
+ sequence_output = outputs[0]
1689
+
1690
+ sequence_output = self.dropout(sequence_output)
1691
+ logits = self.classifier(sequence_output)
1692
+
1693
+ loss = None
1694
+ if labels is not None:
1695
+ loss_fct = CrossEntropyLoss()
1696
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1697
+
1698
+ if not return_dict:
1699
+ output = (logits,) + outputs[2:]
1700
+ return ((loss,) + output) if loss is not None else output
1701
+
1702
+ return TokenClassifierOutput(
1703
+ loss=loss,
1704
+ logits=logits,
1705
+ hidden_states=outputs.hidden_states,
1706
+ attentions=outputs.attentions,
1707
+ )
1708
+
1709
+
1710
+ @auto_docstring
1711
+ class BertForQuestionAnswering(BertPreTrainedModel):
1712
+ def __init__(self, config):
1713
+ super().__init__(config)
1714
+ self.num_labels = config.num_labels
1715
+
1716
+ self.bert = BertModel(config, add_pooling_layer=False)
1717
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1718
+
1719
+ # Initialize weights and apply final processing
1720
+ self.post_init()
1721
+
1722
+ @auto_docstring
1723
+ def forward(
1724
+ self,
1725
+ input_ids: Optional[torch.Tensor] = None,
1726
+ attention_mask: Optional[torch.Tensor] = None,
1727
+ token_type_ids: Optional[torch.Tensor] = None,
1728
+ position_ids: Optional[torch.Tensor] = None,
1729
+ head_mask: Optional[torch.Tensor] = None,
1730
+ inputs_embeds: Optional[torch.Tensor] = None,
1731
+ start_positions: Optional[torch.Tensor] = None,
1732
+ end_positions: Optional[torch.Tensor] = None,
1733
+ output_attentions: Optional[bool] = None,
1734
+ output_hidden_states: Optional[bool] = None,
1735
+ return_dict: Optional[bool] = None,
1736
+ ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1737
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1738
+
1739
+ outputs = self.bert(
1740
+ input_ids,
1741
+ attention_mask=attention_mask,
1742
+ token_type_ids=token_type_ids,
1743
+ position_ids=position_ids,
1744
+ head_mask=head_mask,
1745
+ inputs_embeds=inputs_embeds,
1746
+ output_attentions=output_attentions,
1747
+ output_hidden_states=output_hidden_states,
1748
+ return_dict=return_dict,
1749
+ )
1750
+
1751
+ sequence_output = outputs[0]
1752
+
1753
+ logits = self.qa_outputs(sequence_output)
1754
+ start_logits, end_logits = logits.split(1, dim=-1)
1755
+ start_logits = start_logits.squeeze(-1).contiguous()
1756
+ end_logits = end_logits.squeeze(-1).contiguous()
1757
+
1758
+ total_loss = None
1759
+ if start_positions is not None and end_positions is not None:
1760
+ # On multi-GPU, start/end positions may carry an extra dimension; squeeze it away
1761
+ if len(start_positions.size()) > 1:
1762
+ start_positions = start_positions.squeeze(-1)
1763
+ if len(end_positions.size()) > 1:
1764
+ end_positions = end_positions.squeeze(-1)
1765
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1766
+ ignored_index = start_logits.size(1)
1767
+ start_positions = start_positions.clamp(0, ignored_index)
1768
+ end_positions = end_positions.clamp(0, ignored_index)
1769
+
1770
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1771
+ start_loss = loss_fct(start_logits, start_positions)
1772
+ end_loss = loss_fct(end_logits, end_positions)
1773
+ total_loss = (start_loss + end_loss) / 2
1774
+
1775
+ if not return_dict:
1776
+ output = (start_logits, end_logits) + outputs[2:]
1777
+ return ((total_loss,) + output) if total_loss is not None else output
1778
+
1779
+ return QuestionAnsweringModelOutput(
1780
+ loss=total_loss,
1781
+ start_logits=start_logits,
1782
+ end_logits=end_logits,
1783
+ hidden_states=outputs.hidden_states,
1784
+ attentions=outputs.attentions,
1785
+ )
1786
+
1787
+
1788
+ __all__ = [
1789
+ "BertForMaskedLM",
1790
+ "BertForMultipleChoice",
1791
+ "BertForNextSentencePrediction",
1792
+ "BertForPreTraining",
1793
+ "BertForQuestionAnswering",
1794
+ "BertForSequenceClassification",
1795
+ "BertForTokenClassification",
1796
+ "BertLayer",
1797
+ "BertLMHeadModel",
1798
+ "BertModel",
1799
+ "BertPreTrainedModel",
1800
+ "load_tf_weights_in_bert",
1801
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py ADDED
@@ -0,0 +1,1727 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Callable, Optional
17
+
18
+ import flax
19
+ import flax.linen as nn
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
+ from flax.linen import combine_masks, make_causal_mask
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from flax.traverse_util import flatten_dict, unflatten_dict
28
+ from jax import lax
29
+
30
+ from ...modeling_flax_outputs import (
31
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
32
+ FlaxBaseModelOutputWithPooling,
33
+ FlaxBaseModelOutputWithPoolingAndCrossAttentions,
34
+ FlaxCausalLMOutputWithCrossAttentions,
35
+ FlaxMaskedLMOutput,
36
+ FlaxMultipleChoiceModelOutput,
37
+ FlaxNextSentencePredictorOutput,
38
+ FlaxQuestionAnsweringModelOutput,
39
+ FlaxSequenceClassifierOutput,
40
+ FlaxTokenClassifierOutput,
41
+ )
42
+ from ...modeling_flax_utils import (
43
+ ACT2FN,
44
+ FlaxPreTrainedModel,
45
+ append_call_sample_docstring,
46
+ append_replace_return_docstrings,
47
+ overwrite_call_docstring,
48
+ )
49
+ from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
50
+ from .configuration_bert import BertConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ _CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
56
+ _CONFIG_FOR_DOC = "BertConfig"
57
+
58
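+ # `remat` is Flax's rematerialization (gradient checkpointing) transform, used to wrap layers when gradient checkpointing is enabled.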
+ remat = nn_partitioning.remat
59
+
60
+
61
+ @flax.struct.dataclass
62
+ class FlaxBertForPreTrainingOutput(ModelOutput):
63
+ """
64
+ Output type of [`BertForPreTraining`].
65
+
66
+ Args:
67
+ prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
68
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
69
+ seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
70
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
71
+ before SoftMax).
72
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
73
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
74
+ `(batch_size, sequence_length, hidden_size)`.
75
+
76
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
77
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
78
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
79
+ sequence_length)`.
80
+
81
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
82
+ heads.
83
+ """
84
+
85
+ prediction_logits: jnp.ndarray = None
86
+ seq_relationship_logits: jnp.ndarray = None
87
+ hidden_states: Optional[tuple[jnp.ndarray]] = None
88
+ attentions: Optional[tuple[jnp.ndarray]] = None
89
+
90
+
91
+ BERT_START_DOCSTRING = r"""
92
+
93
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
94
+ library implements for all its models (such as downloading, saving and converting weights from PyTorch models)
95
+
96
+ This model is also a
97
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
98
+ a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
99
+ behavior.
100
+
101
+ Finally, this model supports inherent JAX features such as:
102
+
103
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
104
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
105
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
106
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
107
+
108
+ Parameters:
109
+ config ([`BertConfig`]): Model configuration class with all the parameters of the model.
110
+ Initializing with a config file does not load the weights associated with the model, only the
111
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
112
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
113
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
114
+ `jax.numpy.bfloat16` (on TPUs).
115
+
116
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
117
+ specified all the computation will be performed with the given `dtype`.
118
+
119
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
120
+ parameters.**
121
+
122
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
123
+ [`~FlaxPreTrainedModel.to_bf16`].
124
136
+
137
+ """
138
+
139
+ BERT_INPUTS_DOCSTRING = r"""
140
+ Args:
141
+ input_ids (`numpy.ndarray` of shape `({0})`):
142
+ Indices of input sequence tokens in the vocabulary.
143
+
144
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
145
+ [`PreTrainedTokenizer.__call__`] for details.
146
+
147
+ [What are input IDs?](../glossary#input-ids)
148
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
149
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
150
+
151
+ - 1 for tokens that are **not masked**,
152
+ - 0 for tokens that are **masked**.
153
+
154
+ [What are attention masks?](../glossary#attention-mask)
155
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
156
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
157
+ 1]`:
158
+
159
+ - 0 corresponds to a *sentence A* token,
160
+ - 1 corresponds to a *sentence B* token.
161
+
162
+ [What are token type IDs?](../glossary#token-type-ids)
163
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
164
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
165
+ config.max_position_embeddings - 1]`.
166
+ head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
167
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
168
+
169
+ - 1 indicates the head is **not masked**,
170
+ - 0 indicates the head is **masked**.
171
+
172
+ return_dict (`bool`, *optional*):
173
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
174
+
175
+ """


class FlaxBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states
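

# --- Editorial sketch (not part of the upstream file): the three embedding tables
# above all project into `hidden_size`, so their outputs add elementwise. A minimal
# shape check with toy dimensions; all numbers and the `_demo_*` name are made up.
def _demo_embedding_sum_shapes():
    import jax.numpy as jnp

    batch, seq, hidden = 2, 4, 8
    word = jnp.zeros((batch, seq, hidden))
    position = jnp.zeros((batch, seq, hidden))
    token_type = jnp.zeros((batch, seq, hidden))
    summed = word + token_type + position  # same order as FlaxBertEmbeddings.__call__
    assert summed.shape == (batch, seq, hidden)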


class FlaxBertSelfAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to
        cached states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to
            # those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
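
    # --- Editorial sketch (not part of the upstream method): how one decode step
    # writes into the static cache with `lax.dynamic_update_slice`. Toy shapes; the
    # real cache is (batch, max_length, num_heads, head_dim). For example:
    #
    #   cache = jnp.zeros((1, 4, 2, 3))   # max_length=4, 2 heads, head_dim=3
    #   new_key = jnp.ones((1, 1, 2, 3))  # one freshly projected position
    #   cache = lax.dynamic_update_slice(cache, new_key, (0, 2, 0, 0))  # write slot 2
    #   pad_mask = jnp.arange(4) < 2 + 1  # [True, True, True, False]
    #
    # Only slots up to the running cache_index hold real data, which is exactly
    # what the pad_mask computed above enforces.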

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.query(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.key(key_value_states)
            value_states = self.value(key_value_states)
        else:
            # self_attention
            key_states = self.key(hidden_states)
            value_states = self.value(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle the cache and prepare the causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
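

# --- Editorial sketch (not part of the upstream file): the boolean-mask-to-bias
# conversion above keeps allowed positions at 0 and pushes masked positions to the
# most negative finite value, so softmax drives them to ~0 probability. Toy values
# only; the `_demo_*` name is made up.
def _demo_mask_to_bias():
    import jax
    import jax.numpy as jnp
    from jax import lax

    mask = jnp.array([[1, 1, 0]])  # last key position is padding
    bias = lax.select(
        mask > 0,
        jnp.full(mask.shape, 0.0),
        jnp.full(mask.shape, jnp.finfo(jnp.float32).min),
    )
    probs = jax.nn.softmax(bias)  # masked slot gets (almost) zero probability
    return probs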


class FlaxBertSelfOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
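

# --- Editorial sketch (not part of the upstream file): BERT uses the "post-LN"
# residual pattern seen above, i.e. LayerNorm(sublayer_output + residual), rather
# than the "pre-LN" variant. A minimal standalone rendering with flax.linen; all
# shapes and the `_demo_*` name are made up.
def _demo_post_layernorm_residual():
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    ln = nn.LayerNorm(epsilon=1e-12)
    residual = jnp.ones((1, 4, 8))  # stand-in for the sublayer input
    sublayer_out = jnp.zeros((1, 4, 8))  # stand-in for dense(dropout(...)) output
    params = ln.init(jax.random.PRNGKey(0), sublayer_out + residual)
    return ln.apply(params, sublayer_out + residual)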


class FlaxBertAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


class FlaxBertIntermediate(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxBertOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


class FlaxBertLayer(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBertOutput(self.config, dtype=self.dtype)
        if self.config.add_cross_attention:
            self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        return outputs


class FlaxBertLayerCollection(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
            self.layers = [
                FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            self.layers = [
                FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Check that head_mask, if provided, was specified for the correct number of layers
        if head_mask is not None:
            if head_mask.shape[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f"{head_mask.shape[0]}."
                )

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
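

# --- Editorial sketch (not part of the upstream file): what `remat` (gradient /
# activation checkpointing), used in the setup above, does in isolation: forward
# activations are recomputed during the backward pass to save memory. Minimal flax
# example on a toy Dense layer; the `_demo_*` name is made up.
def _demo_remat():
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    CheckpointedDense = nn.remat(nn.Dense)  # same interface, recomputes activations on backward
    layer = CheckpointedDense(features=4)
    x = jnp.ones((1, 4))
    params = layer.init(jax.random.PRNGKey(0), x)
    loss = lambda p: layer.apply(p, x).sum()
    return jax.grad(loss)(params)  # numerically identical grads, lower peak memory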


class FlaxBertEncoder(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.layer = FlaxBertLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxBertPooler(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]
        cls_hidden_state = self.dense(cls_hidden_state)
        return nn.tanh(cls_hidden_state)
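

# --- Editorial sketch (not part of the upstream file): the pooler above takes only
# the first ([CLS]) position of the sequence output before the dense + tanh. Toy
# shapes for illustration; the `_demo_*` name is made up.
def _demo_cls_pooling():
    import jax.numpy as jnp

    batch, seq, hidden = 2, 5, 8
    sequence_output = jnp.arange(batch * seq * hidden, dtype=jnp.float32).reshape(batch, seq, hidden)
    cls_state = sequence_output[:, 0]  # (batch, hidden): one vector per example
    assert cls_state.shape == (batch, hidden)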


class FlaxBertPredictionHeadTransform(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.LayerNorm(hidden_states)


class FlaxBertLMPredictionHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.transform(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        bias = jnp.asarray(self.bias, self.dtype)
        hidden_states += bias
        return hidden_states
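

# --- Editorial sketch (not part of the upstream file): weight tying as used above.
# When embeddings are tied, the output projection reuses the (vocab, hidden) word
# embedding table transposed to (hidden, vocab) as the Dense kernel. Toy sizes; the
# `_demo_*` name is made up.
def _demo_weight_tying():
    import jax.numpy as jnp
    import flax.linen as nn

    vocab, hidden = 10, 4
    shared_embedding = jnp.ones((vocab, hidden))  # stand-in for the word embedding table
    decoder = nn.Dense(vocab, use_bias=False)
    hidden_states = jnp.ones((1, 3, hidden))
    # inject the transposed embedding as the kernel instead of learned parameters
    logits = decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
    assert logits.shape == (1, 3, vocab)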


class FlaxBertOnlyMLMHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
        return hidden_states


class FlaxBertOnlyNSPHead(nn.Module):
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output):
        return self.seq_relationship(pooled_output)


class FlaxBertPreTrainingHeads(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, hidden_states, pooled_output, shared_embedding=None):
        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BertConfig,
        input_shape: tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        module = self.module_class(
            config=config,
            dtype=dtype,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def enable_gradient_checkpointing(self):
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])
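
    # --- Editorial sketch (not part of the upstream method): typical init_cache use
    # for fast autoregressive decoding with the causal variant (checkpoint name and
    # sizes are illustrative):
    #
    #   model = FlaxBertForCausalLM.from_pretrained("google-bert/bert-base-uncased", is_decoder=True)
    #   past_key_values = model.init_cache(batch_size=1, max_length=16)
    #
    # `past_key_values` then holds zeroed key/value buffers plus a cache_index of 0,
    # ready to be threaded through successive single-token forward passes.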

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: Optional[dict] = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[dict] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        if self.config.add_cross_attention:
            # If past_key_values are passed, the cache is already initialized, so the private flag init_cache
            # has to be passed down to ensure the cache is used. The cache must also be marked as mutable so
            # that it can be changed by the FlaxBertAttention module.
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # add updated cache to model output
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        else:
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
            )

        return outputs


class FlaxBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
    module_class = FlaxBertModule


append_call_sample_docstring(FlaxBertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
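

# --- Editorial sketch (not part of the upstream file): minimal end-to-end forward
# pass through FlaxBertModel. Requires network access to download the checkpoint;
# the `_demo_*` name is made up.
def _demo_flax_bert_model_forward():
    from transformers import AutoTokenizer, FlaxBertModel

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = FlaxBertModel.from_pretrained("google-bert/bert-base-uncased")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    outputs = model(**inputs)
    # (batch, seq, hidden) sequence output and (batch, hidden) pooled [CLS] output
    return outputs.last_hidden_state.shape, outputs.pooler_output.shape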


class FlaxBertForPreTrainingModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        hidden_states = outputs[0]
        pooled_output = outputs[1]

        prediction_scores, seq_relationship_score = self.cls(
            hidden_states, pooled_output, shared_embedding=shared_embedding
        )

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return FlaxBertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    module_class = FlaxBertForPreTrainingModule


FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.seq_relationship_logits
    ```
"""

overwrite_call_docstring(
    FlaxBertForPreTraining,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForMaskedLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMaskedLMModule


append_call_sample_docstring(FlaxBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)


class FlaxBertForNextSentencePredictionModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        if not return_dict:
            return (seq_relationship_scores,) + outputs[2:]

        return FlaxNextSentencePredictorOutput(
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    module_class = FlaxBertForNextSentencePredictionModule


FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    >>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax")

    >>> outputs = model(**encoding)
    >>> logits = outputs.logits
    >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
    ```
"""


overwrite_call_docstring(
    FlaxBertForNextSentencePrediction,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForSequenceClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxBertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxBertForMultipleChoiceModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
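

# --- Editorial sketch (not part of the upstream file): the multiple-choice head
# above flattens (batch, num_choices, seq) into (batch * num_choices, seq) for the
# encoder, then folds the per-choice scores back into (batch, num_choices). Toy
# shapes; the `_demo_*` name is made up.
def _demo_multiple_choice_reshape():
    import jax.numpy as jnp

    batch, num_choices, seq = 2, 4, 7
    input_ids = jnp.zeros((batch, num_choices, seq), dtype=jnp.int32)
    flat = input_ids.reshape(-1, input_ids.shape[-1])  # (8, 7): one row per choice
    per_choice_logits = jnp.zeros((flat.shape[0], 1))  # classifier output, one score each
    reshaped = per_choice_logits.reshape(-1, num_choices)  # (2, 4): scores per example
    assert reshaped.shape == (batch, num_choices)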


@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMultipleChoiceModule


overwrite_call_docstring(
    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxBertForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)


class FlaxBertForTokenClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForTokenClassificationModule


append_call_sample_docstring(
    FlaxBertForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)


class FlaxBertForQuestionAnsweringModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
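

# --- Editorial sketch (not part of the upstream file): the QA head above emits two
# logits per token; `jnp.split` along the last axis separates them into span-start
# and span-end scores. Toy shapes (num_labels == 2 for extractive QA); the `_demo_*`
# name is made up.
def _demo_qa_logits_split():
    import jax.numpy as jnp

    batch, seq, num_labels = 2, 6, 2
    logits = jnp.zeros((batch, seq, num_labels))
    start_logits, end_logits = jnp.split(logits, num_labels, axis=-1)
    start_logits = start_logits.squeeze(-1)  # (batch, seq)
    end_logits = end_logits.squeeze(-1)  # (batch, seq)
    assert start_logits.shape == end_logits.shape == (batch, seq)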


@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    module_class = FlaxBertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxBertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxBertForCausalLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
    autoregressive tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and
        # x < cache_length. But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
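

# --- Editorial sketch (not part of the upstream file): deriving position_ids from
# a left-padded attention mask via cumsum, as done in prepare_inputs_for_generation
# above. Padding slots end up at position -1, which is harmless because they are
# masked out of attention anyway. The `_demo_*` name is made up.
def _demo_position_ids_from_mask():
    import jax.numpy as jnp

    attention_mask = jnp.array([[0, 0, 1, 1, 1]])  # two left-padding tokens
    position_ids = attention_mask.cumsum(axis=-1) - 1
    # -> [[-1, -1, 0, 1, 2]]: real tokens get consecutive positions starting at 0
    assert position_ids.tolist() == [[-1, -1, 0, 1, 2]]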


append_call_sample_docstring(
    FlaxBertForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)


__all__ = [
    "FlaxBertForCausalLM",
    "FlaxBertForMaskedLM",
    "FlaxBertForMultipleChoice",
    "FlaxBertForNextSentencePrediction",
    "FlaxBertForPreTraining",
    "FlaxBertForQuestionAnswering",
    "FlaxBertForSequenceClassification",
    "FlaxBertForTokenClassification",
    "FlaxBertModel",
    "FlaxBertPreTrainedModel",
]