Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py +26 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py +606 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py +226 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py +1047 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py +204 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py +300 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py +31 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py +170 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py +1349 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py +1132 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py +1572 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py +320 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py +178 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py +35 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py +882 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py +1404 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py +422 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py +688 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py +0 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py +413 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py +776 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py +443 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py +1235 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py +393 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py +28 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py +303 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py +330 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py +1628 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py +340 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py +32 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py +154 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py +1801 -0
- URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py +1727 -0
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/__init__.cpython-312.pyc
ADDED: binary file (7.34 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_constraints.cpython-312.pyc
ADDED: binary file (23.1 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/beam_search.cpython-312.pyc
ADDED: binary file (44.8 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/candidate_generator.cpython-312.pyc
ADDED: binary file (61 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/configuration_utils.cpython-312.pyc
ADDED: binary file (76.5 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_logits_process.cpython-312.pyc
ADDED: binary file (32.4 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/flax_utils.cpython-312.pyc
ADDED: binary file (46.5 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/stopping_criteria.cpython-312.pyc
ADDED: binary file (31.2 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/streamers.cpython-312.pyc
ADDED: binary file (14.7 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/tf_logits_process.cpython-312.pyc
ADDED: binary file (40.4 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/__pycache__/watermarking.cpython-312.pyc
ADDED: binary file (28 kB)
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__init__.py
ADDED: +26 lines

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cache import PagedAttentionCache
from .continuous_api import ContinuousBatchingManager, ContinuousMixin
from .requests import RequestState, RequestStatus


__all__ = [
    "ContinuousBatchingManager",
    "ContinuousMixin",
    "PagedAttentionCache",
    "RequestState",
    "RequestStatus",
]
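For orientation, the public surface declared in `__all__` above is reachable in one import; a minimal sketch (the dotted import path is an assumption that mirrors the directory layout shown in this diff):

from transformers.generation.continuous_batching import (
    ContinuousBatchingManager,
    ContinuousMixin,
    PagedAttentionCache,
    RequestState,
    RequestStatus,
)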
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/__init__.cpython-312.pyc
ADDED: binary file (501 Bytes)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache.cpython-312.pyc
ADDED: binary file (31.7 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/cache_manager.cpython-312.pyc
ADDED: binary file (12.4 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/continuous_api.cpython-312.pyc
ADDED: binary file (55.7 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/requests.cpython-312.pyc
ADDED: binary file (10.5 kB)

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/__pycache__/scheduler.cpython-312.pyc
ADDED: binary file (16 kB)
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache.py
ADDED: +606 lines

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from math import floor, gcd, sqrt
from typing import Optional, Union

import torch

from ...configuration_utils import PretrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.metrics import attach_tracer, traced
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import get_device_and_memory_breakdown, logger


def group_layers_by_attn_type(config: PretrainedConfig) -> tuple[list[list[int]], list[str]]:
    """
    Group layers depending on the attention mix, according to vLLM's hybrid allocator rules:
    - Layers in each group need to have the same type of attention
    - All groups have the same number of layers

    For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
    we would get four groups: [0, 3] for the sliding layers, and [1, 2], [4, 5], [6, 7] for the full layers.
    """
    # If the config has no layer_types attribute, it means all layers are the same attention type
    layer_types = getattr(config, "layer_types", None)
    if layer_types is None:
        attn_type = "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention"
        layer_types = [attn_type for _ in range(config.num_hidden_layers)]

    # We then collect the indices of the layers of each type
    layer_counts = {}
    for i, layer_type in enumerate(layer_types):
        layer_counts[layer_type] = layer_counts.get(layer_type, []) + [i]

    # The size of all groups is the greatest common divisor of the number of layers of each type
    group_size = gcd(*[len(indices) for indices in layer_counts.values()])

    # We then group the layers by type
    layer_groups = []
    for layer_type, indices in layer_counts.items():
        for i in range(0, len(indices), group_size):
            layer_groups.append(indices[i : i + group_size])
    # And note the layer types
    group_types = [layer_types[lg[0]] for lg in layer_groups]
    return layer_groups, group_types

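An aside before the class below: to make the grouping rule concrete, here is a small self-contained rerun of the same gcd logic on the docstring's example pattern (plain strings stand in for the `config.layer_types` attribute, so no `PretrainedConfig` is needed):

from math import gcd

# The docstring's example pattern: two sliding layers, six full layers
layer_types = [
    "sliding_attention", "full_attention", "full_attention", "sliding_attention",
    "full_attention", "full_attention", "full_attention", "full_attention",
]
layer_counts = {}
for i, layer_type in enumerate(layer_types):
    layer_counts[layer_type] = layer_counts.get(layer_type, []) + [i]
# gcd(2, 6) = 2, so every group holds exactly two layers
group_size = gcd(*[len(indices) for indices in layer_counts.values()])
layer_groups = [
    indices[i : i + group_size]
    for indices in layer_counts.values()
    for i in range(0, len(indices), group_size)
]
print(layer_groups)  # [[0, 3], [1, 2], [4, 5], [6, 7]]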
@attach_tracer()
class PagedAttentionCache:
    """
    Manages the cache for a paged attention mechanism, inspired by vLLM's hybrid allocator. The cache relies on making
    groups of layers to reduce the complexity of cache management and fragmentation.

    The cache uses a three-level hierarchy:
    - Pages: The smallest unit of cache, a page has a size of [num_heads, head_size], which is the space needed to
      store the key or value states for one token and one layer. For a model with only full-attention layers, to store
      the KV cache of one token, we need `2 * num_layers` pages: key and values each take `num_layers` pages.
      Pages are grouped into blocks:
    - Blocks: A block is a collection of `block_size` pages, serving as the allocation unit to reduce management
      complexity and fragmentation. Cache is allocated and freed block by block, not page by page. One block is
      allocated to one layer group, which only has one attention type, like full-attention or sliding-attention.
      If all layers in the model have the same attention type, then all layers will be in the same group. There is
      more than one group if and only if the model has mixed attention types, like layers with full-attention and
      layers with sliding-attention.
    - Cache tensors: The physical supports for the cache. There are as many cache tensors as there are layers in a
      layer group, and the shape of the cache tensors is `[num_blocks * block_size, num_heads, head_size]`.

    Grouping layers into groups is useful because when we allocate one block to a group N, the block allocated is the
    same for all layers in group N; equivalently, it is allocated across all cache tensors. This allows us to
    efficiently allocate and free blocks, and to efficiently read and write key and value states.

    For instance, imagine we have 8 blocks of cache and a model with two layer groups: a full-attention group with 3
    layers and a sliding-attention group with 3 layers. At creation time, the physical cache tensors look like this:

    cache_tensor_0: □ □ □ □ □ □ □ □
    cache_tensor_1: □ □ □ □ □ □ □ □
    cache_tensor_2: □ □ □ □ □ □ □ □

    where □ means the block is not allocated to any layer group yet. We have 3 cache tensors because there are
    3 layers per group.
    We allocate 1 block to each group; after allocation, the cache tensors look like this:

    cache_tensor_0: ✖ ◉ □ □ □ □ □ □
    cache_tensor_1: ✖ ◉ □ □ □ □ □ □
    cache_tensor_2: ✖ ◉ □ □ □ □ □ □

    where ✖ means the block is allocated to the full-attention group, and ◉ means the block is allocated to the
    sliding-attention group.
    Now, if we continue to generate, and the sliding window has been reached, we only need to allocate a new block
    for the full-attention group, and the cache tensors look like this:

    cache_tensor_0: ✖ ◉ ✖ □ □ □ □ □
    cache_tensor_1: ✖ ◉ ✖ □ □ □ □ □
    cache_tensor_2: ✖ ◉ ✖ □ □ □ □ □

    And after further generation, when we need a new block allocated:

    cache_tensor_0: ✖ ◉ ✖ ✖ □ □ □ □
    cache_tensor_1: ✖ ◉ ✖ ✖ □ □ □ □
    cache_tensor_2: ✖ ◉ ✖ ✖ □ □ □ □

    This would not have been possible if all layers were in the same group: we would have had to allocate a new block
    for the sliding-attention group, although it is not needed.
    """

    # TODO: this init is quite long, maybe a refactor is in order
    def __init__(
        self,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        device: torch.device,
        dtype: torch.dtype = torch.float16,
        layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        """Initialize a paged attention cache for efficient memory usage.

        Args:
            config: Model configuration
            generation_config: Generation configuration containing cache parameters
            device: Device for the cache tensors
            dtype: Data type of the cache
            layer_device_map: Optional mapping of layer indices to devices
            tp_size: Tensor parallelism size
        """
        self.config = config
        self.dtype = dtype
        self.device = device

        # Extract model dimensions
        kv_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads
        head_dim = getattr(config, "head_dim", None)
        self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads

        # Extract cache dimensions
        self.block_size = getattr(generation_config, "block_size", 32)

        # Group layers depending on the attention mix
        layer_groups, group_types = group_layers_by_attn_type(config)
        group_size = len(layer_groups[0])
        self.num_groups = len(layer_groups)

        self.sliding_windows = {}
        self.layer_index_to_group_indices = {}
        for i, group in enumerate(layer_groups):
            sliding_window = config.sliding_window if group_types[i] == "sliding_attention" else 1
            for j, layer in enumerate(group):
                self.layer_index_to_group_indices[layer] = (i, j)
                self.sliding_windows[layer] = sliding_window

        # Handle TP (or don't)
        if tp_size is not None and tp_size > 1:
            if self.num_key_value_heads % tp_size != 0:
                raise ValueError(
                    f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
                )
            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
            # self.num_key_value_heads //= tp_size  # TODO: why is this commented out?

        # Infer number of blocks and max batch tokens
        page_size = self.head_dim * self.num_key_value_heads

        if getattr(config, "attn_implementation", None) == "paged_attention":
            num_attention_masks = 0
        else:
            # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
            num_attention_masks = 2 if "sliding_attention" in group_types else 1

        memory_handler = PagedAttentionMemoryHandler(
            block_size=self.block_size,
            page_size=page_size,
            num_groups=self.num_groups,
            group_size=group_size,
            peak_activation_per_token=(config.hidden_size + config.vocab_size),
            num_attention_masks=num_attention_masks,
        )
        num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
            num_blocks=getattr(generation_config, "num_blocks", None),
            max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
            max_memory_percent=getattr(generation_config, "max_memory", 0.9),
            cache_dtype=self.dtype,
        )

        # Add the inferred attributes to the class
        self.num_blocks = num_blocks
        self.max_batch_tokens = max_batch_tokens
        logger.info(
            f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, "
            f"{self.max_batch_tokens = } {num_attention_masks = }"
        )

        # Initialize the cache
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        # We add one extra token to the cache to handle padding and generally discard unwanted tokens
        self.cache_shape = (num_blocks * self.block_size + 1, self.num_key_value_heads, self.head_dim)
        for _ in range(group_size):
            new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
            new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
        logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")

        # Block management data structures
        self._free_blocks = deque(range(num_blocks))
        self.group_cache_managers: list[CacheAllocator] = []
        for i, group_type in enumerate(group_types):
            if group_type == "full_attention":
                cm = FullAttentionCacheAllocator(i, self.block_size)
            elif group_type == "sliding_attention":
                cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
            else:
                raise ValueError(f"Invalid group type: {group_type}")
            self.group_cache_managers.append(cm)

    @traced
    def allocate_blocks(self, n_blocks: int, request_id: str) -> Optional[int]:
        """Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
        managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
        max_allocated = 0
        for cm in self.group_cache_managers:
            allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
            if allocated is None:
                return None
            max_allocated = max(max_allocated, allocated)
        return max_allocated

    @traced
    def free_blocks(self, request_id: str) -> None:
        """Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
        by the cache managers."""
        for cm in self.group_cache_managers:
            cm.free_blocks(request_id, self._free_blocks)

    def get_num_free_blocks(self) -> int:
        """Get the current number of unallocated blocks available for new requests."""
        return len(self._free_blocks)

    @traced
    def extend_read_indices(
        self, request_id: str, past_length: int, query_length: int, read_index: list[list[int]]
    ) -> None:
        """Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method
        coordinates with all cache managers to build the complete set of read indices needed for attention computation.
        """
        for cm, read_indices in zip(self.group_cache_managers, read_index):
            indices = cm.get_read_indices(request_id, past_length, query_length)
            read_indices.extend(indices)

    @traced
    def extend_write_indices(
        self, request_id: str, past_length: int, query_length: int, write_index: list[list[int]]
    ) -> None:
        """Retrieve physical cache indices for writing new KV states to the cache across all layer groups. This method
        coordinates with all cache managers to build the complete set of write indices needed to store computed KV
        states."""
        for cm, write_indices in zip(self.group_cache_managers, write_index):
            indices = cm.get_write_indices(request_id, past_length, query_length)
            write_indices.extend(indices)

    @traced
    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> dict[str, int]:
        """Retrieve the key sequence length for the given request_id across all layer types. Returns a dictionary of
        layer types to their corresponding key sequence lengths."""
        seqlens_k = {}
        for cm in self.group_cache_managers:
            attn_type, seqlen_k = cm.get_seqlens_k(request_id, past_length, query_length)
            seqlens_k[attn_type] = seqlen_k
        return seqlens_k

    @traced
    def update(
        self,
        key_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
        value_states: torch.Tensor,  # shape [1, num_kv_heads, seqlen_kv, head_dim]
        layer_idx: int,
        read_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_kv + past_length]
        write_index: list[torch.Tensor],  # shape [num_layer_groups, seqlen_q]
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:  # shape [seqlen_kv + past_length, num_kv_heads, head_dim]
        """Update the cache with new key-value states for a specific layer. This method writes new KV states to the
        appropriate cache locations. The behavior differs based on the layer's attention type:

        - Full attention: new KV states are written to the cache, then the complete sequence is read from the cache
        - Sliding window: old KV is read from the cache along with extra spaces for the new KV, then the new KV is
          written to the cache. This is because the new KV might overwrite the old KV, so we need to read the old KV
          first.

        Returns the complete KV states (cached + new) for attention computation.
        """
        # Retrieve the layer read and write indices, and whether there is a sliding window
        group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx]
        layer_read_index = read_index[group_idx]
        layer_write_index = write_index[group_idx]
        # Select the correct cache
        k_cache = self.key_cache[layer_idx_in_group]
        v_cache = self.value_cache[layer_idx_in_group]
        # Transpose the key and value states to match the cache shape, after which shape is [seqlen_kv, num_kv_heads, head_dim]
        key_states = key_states.transpose(1, 2).squeeze(0)
        value_states = value_states.transpose(1, 2).squeeze(0)

        # Case: full attention
        sliding_window = self.sliding_windows[layer_idx]
        if sliding_window == 1:
            k_cache[layer_write_index, :, :] = key_states
            v_cache[layer_write_index, :, :] = value_states
            key_states_with_cache = k_cache[layer_read_index, :, :]
            value_states_with_cache = v_cache[layer_read_index, :, :]

        # Case: sliding window -- we need to be careful of read/write order because of chunked prefill, because it's
        # the only case where you may write over cache you need to use
        else:
            # Add the cache to the key and value states
            mask = layer_read_index == -1  # TODO: can this be efficiently precomputed?
            key_states_with_cache = k_cache[layer_read_index, :, :]
            key_states_with_cache[mask] = key_states
            value_states_with_cache = v_cache[layer_read_index, :, :]
            value_states_with_cache[mask] = value_states
            # Write new KV values to the cache
            k_cache[layer_write_index, :, :] = key_states
            v_cache[layer_write_index, :, :] = value_states

        # Return the new KV values
        return key_states_with_cache, value_states_with_cache

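Before the sizing helper below, a quick back-of-the-envelope check of the page/block hierarchy the class docstring describes (all dimensions are made-up round numbers, not taken from any particular model):

# Hypothetical dims: 16 KV heads of size 128, fp16 cache (2 bytes/elem), block_size = 32
num_kv_heads, head_dim, dtype_size, block_size = 16, 128, 2, 32
page_size = num_kv_heads * head_dim    # elements needed per token, per layer, for K or V
page_bytes = page_size * dtype_size    # 4096 bytes per page
block_bytes = block_size * page_bytes  # one allocation unit: 131072 bytes (128 KiB)
print(page_bytes, block_bytes)         # 4096 131072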
# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
    """A helper class to determine the best number of pages and maximum number of tokens per batch for the paged
    attention cache, providing automatic sizing based on available GPU memory.
    The helper works using the number of pages, which is tied to the number of blocks by:
        num_blocks = num_pages // block_size

    The memory footprint consists of three main components:
    - Cache memory: the space needed to store the cache tensors:
        2 * layer_group_size * [num_pages, page_size] * cache_dtype
    - Activation memory: the space temporarily taken by the largest activation during the model forward pass:
        peak_activation_per_token * max_tokens_per_batch * activation_dtype_size
    - Static tensors: the space taken by the input/output buffers and metadata tensors for batch processing, sum of:
        - inputs_ids + outputs_ids + position_ids + logits_indices: 4 * max_tokens_per_batch * int32_size
        - attention_mask: num_attention_masks * num_pages * max_tokens_per_batch * activation_dtype_size
        - cumulative_seqlens_q + cumulative_seqlens_k: (1 + 2) * max_tokens_per_batch * int32_size
        - write_index_tensor: num_groups * max_tokens_per_batch * int32_size
        - read_index_tensor: num_groups * (num_pages + max_tokens_per_batch) * int32_size

    The handler can operate in three modes:
    1. Auto-sizing: determines both the number of pages and the maximum number of tokens per batch using quadratic
       optimization
    2. Fixed cache: calculates max batch tokens given a fixed number of pages
    3. Fixed batch: calculates the number of pages given a fixed maximum batch size
    """

    _activation_dtype = torch.bfloat16
    _input_dtype = torch.int32
    _upper_bound_max_batch_tokens = 256
    _upper_bound_num_blocks = 4096

    def __init__(
        self,
        block_size: int,
        page_size: int,
        num_groups: int,
        group_size: int,
        peak_activation_per_token: int,
        num_attention_masks: int,
    ) -> None:
        """Initialize the memory handler with the parameters that cannot be automatically inferred.

        Args:
            block_size: Size of the cache blocks
            page_size: Size of the cache pages
            num_groups: Number of layer groups
            group_size: Number of layers per layer group
            peak_activation_per_token: Maximum size of the activation tensor per token, = hidden_size + vocab_size
            num_attention_masks: Number of attention masks, 0 if no attention mask is used, 2 if hybrid model, else 1
        """
        self.block_size = block_size
        self.page_size = page_size
        self.num_groups = num_groups
        self.group_size = group_size
        self.peak_activation_per_token = peak_activation_per_token
        self.num_attention_masks = num_attention_masks

    @staticmethod
    def get_available_memory(max_memory_percent: float = 1.0) -> int:
        """Calculate available GPU memory for cache allocation, accounting for already allocated tensors.
        This method queries the current memory state and applies the specified percentage limit to determine
        how much memory can be safely used for the paged attention cache.

        Args:
            max_memory_percent: Fraction of available memory to use (0.0-1.0). 1.0 means use all available memory.

        Returns:
            int: Available memory in bytes for cache allocation
        """
        _, total, reserved, allocated = get_device_and_memory_breakdown()
        available_memory = total - max(allocated, reserved)
        available_memory = int(available_memory * max_memory_percent)
        return available_memory

    def infer_num_blocks_and_max_batch_tokens(
        self,
        num_blocks: Optional[int] = None,
        max_batch_tokens: Optional[int] = None,
        max_memory_percent: float = 0.9,
        cache_dtype: torch.dtype = torch.float16,
    ) -> tuple[int, int]:
        """Determine the optimal number of blocks and maximum number of tokens per batch based on available memory
        and constraints. Check the class docstring for more details. Naming the number of pages N and the maximum
        number of tokens per batch M, the equation solved is:

            available_memory = sum([
                MN * num_attention_masks * activation_dtype_size,
                2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
                M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
            ])

        where we already simplified int32_size = 4.
        """
        # If neither num_blocks nor max_batch_tokens is provided, we use a second-order polynomial
        if num_blocks is None and max_batch_tokens is None:
            num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(
                max_memory_percent, cache_dtype
            )
        # If only num_blocks is provided, we infer the max_batch_tokens
        elif num_blocks is not None and max_batch_tokens is None:
            max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype)
        # If only max_batch_tokens is provided, we infer the num_blocks
        elif max_batch_tokens is not None and num_blocks is None:
            num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype)

        # We check if the memory footprint is too large in all cases
        available_memory = self.get_available_memory(max_memory_percent)
        memory_footprint = self.compute_memory_footprint(
            max_batch_tokens=max_batch_tokens,
            num_blocks=num_blocks,
            cache_dtype=cache_dtype,
        )
        if memory_footprint > available_memory:
            raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}")
        return num_blocks, max_batch_tokens

    def compute_num_blocks_and_max_batch_tokens(
        self,
        max_memory_percent: float = 0.9,
        cache_dtype: torch.dtype = torch.float16,
        m: float = 0.01,
    ) -> tuple[int, int]:
        """Calculate the optimal number of blocks and maximum number of tokens per batch using quadratic optimization
        when neither is fixed. This method assumes a relationship M = m * N where m is a small ratio below 1 and
        solves the resulting quadratic equation to find the optimal N that maximizes utilization within memory
        constraints. m is the amount of cache we can fill with one batch: m=0.01 means a batch fills at most 1% of the
        cache. The equation to solve is:

            available_memory = sum([
                m * N^2 * num_attention_masks * activation_dtype_size,
                2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
                m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
            ])
        """
        cache_memory = self.get_available_memory(max_memory_percent)
        logger.info(f"Cache memory: {cache_memory}")

        # Compute second-degree polynomial coefficients
        a = m * self.num_attention_masks * self._activation_dtype.itemsize
        b = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
        b += m * (self.peak_activation_per_token * self._activation_dtype.itemsize + 28 + 4 * self.num_groups)
        c = -cache_memory
        logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")

        # Compute discriminant and greatest solution
        discriminant = b**2 - 4 * a * c
        if discriminant < 0:
            raise ValueError(f"Discriminant is negative: {discriminant = }")
        greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
        if greatest_solution < 0:
            raise ValueError(f"Greatest solution is negative: {greatest_solution = }")

        # Infer number of blocks and max batch tokens
        num_pages = floor(greatest_solution)
        num_blocks = num_pages // self.block_size
        if num_blocks > self._upper_bound_num_blocks:
            logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
            num_blocks = self._upper_bound_num_blocks
        max_batch_tokens = int(greatest_solution * m)
        if max_batch_tokens > self._upper_bound_max_batch_tokens:
            logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
            max_batch_tokens = self._upper_bound_max_batch_tokens
        return num_blocks, max_batch_tokens

    def compute_max_batch_tokens(
        self,
        num_blocks: int,
        max_memory_percent: float = 0.9,
        cache_dtype: torch.dtype = torch.float16,
    ) -> int:
        """Calculate the maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:

            M = (available_memory - 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group))
                / (activation_dtype_size * (N * num_attention_masks + peak_activation_per_token) + 28 + 4 * num_group)
        """
        cache_memory = self.get_available_memory(max_memory_percent)
        num_pages = num_blocks * self.block_size
        # Compute numerator
        num = cache_memory
        num -= 2 * num_pages * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
        # Compute denominator
        denum = self._activation_dtype.itemsize * (
            num_pages * self.num_attention_masks + self.peak_activation_per_token
        )
        denum += 28 + 4 * self.num_groups
        # Compute max batch tokens and return
        max_batch_tokens = floor(num / denum)
        if max_batch_tokens > self._upper_bound_max_batch_tokens:
            logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
            max_batch_tokens = self._upper_bound_max_batch_tokens
        return max_batch_tokens

    def compute_num_blocks(
        self,
        max_batch_tokens: int,
        max_memory_percent: float = 0.9,
        cache_dtype: torch.dtype = torch.float16,
    ) -> int:
        """Calculate the number of cache blocks N given a fixed maximum number of tokens per batch M. The formula for
        N is given by:

            N = (available_memory - M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group))
                / (2 * (layer_group_size * page_size * cache_dtype + 2 * num_group) + M * (num_attention_masks * activation_dtype_size))
        """
        cache_memory = self.get_available_memory(max_memory_percent)
        # Compute numerator
        num = cache_memory
        num -= max_batch_tokens * self.peak_activation_per_token * self._activation_dtype.itemsize
        num -= max_batch_tokens * (28 + 4 * self.num_groups)
        # Compute denominator
        denum = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
        denum += max_batch_tokens * (self.num_attention_masks * self._activation_dtype.itemsize)
        denum += max_batch_tokens * self._activation_dtype.itemsize
        # Compute cache size and return number of blocks
        num_pages = floor(num / denum)
        num_blocks = num_pages // self.block_size
        if num_blocks > self._upper_bound_num_blocks:
            logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
            num_blocks = self._upper_bound_num_blocks
        return num_blocks

    def compute_memory_footprint(
        self,
        num_blocks: Optional[int] = None,
        max_batch_tokens: Optional[int] = None,
        cache_dtype: torch.dtype = torch.float16,
    ) -> int:
        """Calculate the total memory footprint for a given number of blocks and maximum batch tokens. The memory
        footprint is given by:

            available_memory = sum([
                MN * num_attention_masks * activation_dtype_size,
                2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
                M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
            ])

        but is broken down below.
        """
        num_pages = num_blocks * self.block_size

        cache_memory_footprint = 2 * self.group_size * num_pages * self.page_size * cache_dtype.itemsize

        activation_memory_footprint = self.peak_activation_per_token * self._activation_dtype.itemsize
        activation_memory_footprint *= max_batch_tokens

        inputs_outputs_positions_and_logits_memory_footprint = 4 * max_batch_tokens * 4  # second 4 is for int32 size

        attention_memory_footprint = self.num_attention_masks * self._activation_dtype.itemsize
        attention_memory_footprint *= num_pages * max_batch_tokens

        cumulative_seqlens_memory_footprint = 3 * max_batch_tokens * 4  # 4 is for int32 size

        write_index_memory_footprint = self.num_groups * max_batch_tokens * 4  # 4 is for int32 size
        read_index_memory_footprint = self.num_groups * (num_pages + max_batch_tokens) * 4  # 4 is for int32 size

        total_memory_footprint = sum(
            [
                cache_memory_footprint,
                activation_memory_footprint,
                inputs_outputs_positions_and_logits_memory_footprint,
                attention_memory_footprint,
                cumulative_seqlens_memory_footprint,
                write_index_memory_footprint,
                read_index_memory_footprint,
            ]
        )
        return total_memory_footprint
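For intuition about the sizing math, a hedged usage sketch of the handler above (the dimensions are invented round numbers, and the import path assumes the package layout shown in this diff; `compute_memory_footprint` is pure arithmetic, so nothing here touches a GPU):

import torch
from transformers.generation.continuous_batching.cache import PagedAttentionMemoryHandler

handler = PagedAttentionMemoryHandler(
    block_size=32,
    page_size=16 * 128,                      # num_kv_heads * head_dim (hypothetical)
    num_groups=2,                            # hybrid model: one full- and one sliding-attention group
    group_size=16,                           # layers per group (hypothetical)
    peak_activation_per_token=4096 + 32000,  # hidden_size + vocab_size (hypothetical)
    num_attention_masks=2,
)
# With 1024 blocks of 32 pages each and 256 tokens per batch, the cache tensors
# dominate: 2 * 16 * 32768 * 2048 * 2 bytes is about 4 GiB of the total
footprint = handler.compute_memory_footprint(num_blocks=1024, max_batch_tokens=256, cache_dtype=torch.float16)
print(f"{footprint / 2**30:.2f} GiB")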
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/cache_manager.py
ADDED: +226 lines

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections import deque
from math import ceil
from typing import Optional

from .requests import logger


class CacheAllocator(ABC):
    """Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
    when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""

    _index: int
    _block_table: dict[str, list[int]]  # request_id -> list of block_ids allocated to the request

    @abstractmethod
    def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
        """Allocates n_blocks for a given request_id. Returns the number of blocks allocated if successful and None
        otherwise."""
        pass

    def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
        """Frees all blocks associated with a request_id."""
        if request_id in self._block_table:
            blocks_to_free = self._block_table.pop(request_id)
            free_blocks.extend(blocks_to_free)
        else:
            logger.warning(
                f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
            )

    @abstractmethod
    def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
        """Returns the physical indices of where to read request_id's cache in the cache tensor."""
        pass

    @abstractmethod
    def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
        """Returns the physical indices of where to write request_id's cache in the cache tensor."""
        pass

    @abstractmethod
    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
        """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
        pass


class FullAttentionCacheAllocator(CacheAllocator):
    """Cache manager for a group of full attention layers."""

    def __init__(self, index: int, block_size: int) -> None:
        """Initializes the cache manager for a group of full attention layers.
        Args:
            - index: the index of the associated layer group
            - block_size: the size of the blocks in the cache
        """
        self._index = index
        self.block_size = block_size
        self._block_table = {}

    def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
        """Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
        otherwise. For a group of full attention layers, we always allocate the number of requested blocks."""
        if len(free_blocks) < n_blocks:
            return None
        if request_id not in self._block_table:
            self._block_table[request_id] = []
        self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
        return n_blocks

    def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
        """Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
        first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
        # Retrieve the block table for the request and raise an error if it doesn't exist
        block_table = self._block_table.get(request_id)
        if block_table is None:
            raise ValueError(f"No block table found for request {request_id}")
        # Compute the physical indices
        physical_indices = []
        for i in range(past_length + query_length):
            block_idx = i // self.block_size
            block_offset = i % self.block_size
            physical_index = block_table[block_idx] * self.block_size + block_offset
            physical_indices.append(physical_index)
        return physical_indices

    def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
        """Returns the physical indices for writing to the cache. For a group of full attention layers, we write the
        new cache as a continuation of the existing cache for the same request."""
        block_table = self._block_table.get(request_id)
        if block_table is None:
            raise ValueError(f"No block table found for request {request_id}")
        # Compute the physical indices
        physical_indices = []
        for i in range(past_length, past_length + query_length):
            block_idx = i // self.block_size
            block_offset = i % self.block_size
            physical_index = block_table[block_idx] * self.block_size + block_offset
            physical_indices.append(physical_index)
        return physical_indices

    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
        """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
        seqlens_k = past_length + query_length
        return "full_attention", seqlens_k


class SlidingAttentionCacheAllocator(CacheAllocator):
    """Cache manager for sliding window attention layers."""

    def __init__(self, index: int, block_size: int, sliding_window: int) -> None:
        """Initializes the cache manager for a group of sliding window attention layers.
        Args:
            - index: the index of the associated layer group
            - block_size: the size of the blocks in the cache
            - sliding_window: the size of the sliding window
        """
        self._index = index
        self.block_size = block_size
        self.sliding_window = sliding_window
        self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
        self._block_table = {}

    def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
        """Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
        otherwise. For a group of sliding window attention layers, we only allocate up to the point where we can fit
        an entire sliding window in the cache tensor."""
        if request_id not in self._block_table:
            self._block_table[request_id] = []
        # Early return if we are already at the max number of blocks per request
        already_allocated = len(self._block_table[request_id])
        if already_allocated == self._max_blocks_per_request:
            return 0
        # Compute actual number of blocks to allocate
        after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
        actual_n_blocks = after_allocation - already_allocated
        # Classic allocation
        if len(free_blocks) < actual_n_blocks:
            return None
        self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
        return actual_n_blocks

    def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
        """Returns the physical indices of where to read request_id's cache in the cache tensor.
        For a group of sliding window attention layers, we read from the cache tensor before writing to it, because
        the new cache can overwrite the old one. To form the cache + new key/value states, we read at most
        sliding_window - 1 cached pages and then manually add the new key/value states after. Hence the -1 indices,
        which indicate where to store the new key or value states."""
        # Retrieve the block table for the request and raise an error if it doesn't exist
        block_table = self._block_table.get(request_id)
        if block_table is None:
            raise ValueError(f"No block table found for request {request_id}")
        # Apply sliding window
        start_index = 0 if past_length < self.sliding_window else past_length % self.sliding_window
        cache_length = min(past_length, self.sliding_window - 1)
        # Compute the physical indices
        physical_indices = []
        for i in range(start_index, start_index + cache_length):
            i %= self.sliding_window
            block_idx = i // self.block_size
            block_offset = i % self.block_size
            physical_index = block_table[block_idx] * self.block_size + block_offset
            physical_indices.append(physical_index)
        return physical_indices + [-1] * query_length

    def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
        """Returns the physical indices of where to write request_id's cache in the cache tensor. For a group of
        sliding window attention layers, we write the new cache in a rolling-buffer kind of way: if we reach the end
        of the allocated physical cache, we start writing from the beginning of the physical cache again."""
        # Retrieve the block table for the request and raise an error if it doesn't exist
        block_table = self._block_table.get(request_id)
        if block_table is None:
            raise ValueError(f"No block table found for request {request_id}")
        # Apply sliding window
        start_index = past_length % self.sliding_window
        cache_length = min(query_length, self.sliding_window)
        padding_length = query_length - cache_length
        # Compute the physical indices
        physical_indices = []
        for i in range(start_index, start_index + cache_length):
for i in range(start_index, start_index + cache_length):
|
| 195 |
+
i %= self.sliding_window
|
| 196 |
+
block_idx = i // self.block_size
|
| 197 |
+
block_offset = i % self.block_size
|
| 198 |
+
physical_index = block_table[block_idx] * self.block_size + block_offset
|
| 199 |
+
physical_indices.append(physical_index)
|
| 200 |
+
if padding_length > 0:
|
| 201 |
+
physical_indices = [-1] * padding_length + physical_indices
|
| 202 |
+
return physical_indices
|
| 203 |
+
|
| 204 |
+
def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
|
| 205 |
+
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
|
| 206 |
+
seqlens_k = query_length + min(past_length, self.sliding_window - 1)
|
| 207 |
+
return "sliding_attention", seqlens_k
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
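And a similar sketch for the rolling buffer above (made-up sizes again): with block_size=2 and sliding_window=4, a request that already holds 6 tokens wraps around and overwrites its oldest slot:

    from collections import deque

    allocator = SlidingAttentionCacheAllocator(index=0, block_size=2, sliding_window=4)
    allocator.allocate_blocks(2, "req0", deque(range(4)))  # block table becomes [0, 1], capped at ceil(4 / 2) blocks
    print(allocator.get_write_indices("req0", past_length=6, query_length=1))  # [2]: slot 6 % 4 = 2 is overwritten
    print(allocator.get_read_indices("req0", past_length=6, query_length=1))   # [2, 3, 0, -1]: 3 cache pages + one -1 slot
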
# TODO: test the impact of this
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
#     # Retrieve the block table for the request and raise an error if it doesn't exist
#     block_table = self._block_table.get(request_id)
#     if block_table is None:
#         raise ValueError(f"No block table found for request {request_id}")
#     # Compute the physical indices
#     physical_indices = []
#     n_left = past_length
#     for block_idx in block_table:
#         block_physical_index = block_idx * self.block_size
#         pages_used = min(self.block_size, n_left)
#         physical_indices.extend(block_physical_index + i for i in range(pages_used))
#         n_left -= pages_used
#         if n_left == 0:
#             return physical_indices
#     raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")

URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/continuous_api.py
ADDED
@@ -0,0 +1,1047 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import threading
from dataclasses import dataclass
from functools import partial
from itertools import count
from time import perf_counter
from typing import Optional, Union

import torch
from torch import nn
from tqdm import tqdm

from ...configuration_utils import PretrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.logging import logging
from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
from .cache import PagedAttentionCache
from .requests import GenerationOutput, RequestState, RequestStatus, get_device_and_memory_breakdown, logger
from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler


def build_attention_mask(
    attention_mask: torch.Tensor,
    cumulative_seqlens_q: torch.Tensor,
    cumulative_seqlens_k: torch.Tensor,
    sliding_window: int = 1,
) -> None:
    """Builds an attention mask in place using the cumulative seqlens of the query and key. If given a sliding window,
    it will also apply a sliding window mask on top. The attention mask is not boolean; it uses zeroes and -inf (or
    its equivalent), so it is more of an attention score bias tensor.
    The attention mask is a block-diagonal matrix, with each block an attention mask for a single query-key pair.
    Each of those blocks is built from a causal mask and, if there is a sliding window, a sliding window mask.

    An example is represented below, with seqlen_k = 8, seqlen_q = 4 and sliding_window = 6:

    CAUSAL MASK:

    █ █ █ █ █ ░ ░ ░
    █ █ █ █ █ █ ░ ░
    █ █ █ █ █ █ █ ░
    █ █ █ █ █ █ █ █

    SLIDING WINDOW MASK:
      ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 8 - 4 - 6 = -2 offset to the right
    <─┴─>
    ░ █ | █ █ █ █ █ █ █ █
    ░ ░ | █ █ █ █ █ █ █ █
    ░ ░ | ░ █ █ █ █ █ █ █
    ░ ░ | ░ ░ █ █ █ █ █ █

    ATTENTION MASK (sum of causal and sliding window masks):

    █ █ █ █ █ ░ ░ ░
    █ █ █ █ █ █ ░ ░
    ░ █ █ █ █ █ █ ░
    ░ ░ █ █ █ █ █ █

    Another example with seqlen_k = 5, seqlen_q = 3 and sliding_window = 2:

    CAUSAL MASK:

    █ █ █ ░ ░
    █ █ █ █ ░
    █ █ █ █ █

    SLIDING WINDOW MASK:
     ┌──────────────────────── seqlen_k - seqlen_q - sliding_window = 5 - 3 - 2 = 0 offset to the right
    <┴>
    | ░ █ █ █ █
    | ░ ░ █ █ █
    | ░ ░ ░ █ █

    ATTENTION MASK (sum of causal and sliding window masks):

    ░ █ █ ░ ░
    ░ ░ █ █ ░
    ░ ░ ░ █ █

    """
    min_value = torch.finfo(attention_mask.dtype).min
    for i in range(len(cumulative_seqlens_q) - 1):
        seqlen_q = cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
        seqlen_k = cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i]
        if seqlen_q < seqlen_k and seqlen_q >= 1:
            causal_diagonal = seqlen_k - seqlen_q + 1
        else:
            causal_diagonal = 1
        query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1])
        key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1])
        # Apply causal mask
        minus_inf = torch.full(
            attention_mask[..., query_range, key_range].shape,
            min_value,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        masked = torch.triu(minus_inf, diagonal=causal_diagonal)
        # Apply sliding window mask if needed
        if sliding_window > 1:
            sliding_diagonal = seqlen_k - seqlen_q - sliding_window
            masked += torch.tril(minus_inf, diagonal=sliding_diagonal)
        # Replace in attention mask
        attention_mask[..., query_range, key_range] = masked

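A minimal stand-alone sketch of the mask builder, reproducing the first docstring example (the dtype and the single query-key pair are illustrative; note the function also works with plain Python lists for the cumulative lengths, which is how `_build_tensors` below calls it):

    import torch

    mask = torch.full((1, 1, 4, 8), torch.finfo(torch.float32).min)
    build_attention_mask(mask, cumulative_seqlens_q=[0, 4], cumulative_seqlens_k=[0, 8], sliding_window=6)
    print((mask[0, 0] == 0).int())  # 1s mark attendable positions, matching the ATTENTION MASK diagram above
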

@dataclass
class PagedAttentionArgs:
    input_ids: torch.Tensor
    attention_mask: Optional[torch.Tensor]
    position_ids: torch.Tensor
    cumulative_seqlens_q: torch.Tensor
    cumulative_seqlens_k: torch.Tensor
    max_seqlen_q: int
    max_seqlen_k: int
    write_index: list[torch.Tensor]
    read_index: list[torch.Tensor]
    logits_indices: torch.Tensor
    cache: PagedAttentionCache
    use_cache: bool = False


# Continuous Batch Processor (Internal Logic)
@attach_tracer()
class ContinuousBatchProcessor:
    def __init__(
        self,
        cache: PagedAttentionCache,
        config: PretrainedConfig,
        generation_config: GenerationConfig,
        input_queue: queue.Queue,
        output_queue: queue.Queue,
        stop_event: threading.Event,
        model_device: torch.device,
        model_dtype: torch.dtype,
        scheduler: Scheduler,
        streaming: bool = False,
        manual_eviction: bool = False,
        slice_inputs: bool = True,  # TODO: there should be a heuristic to decide on slicing, compile, cuda graphs...
    ) -> None:
        """Initialize the continuous batch processor.

        Args:
            cache: A [`PagedAttentionCache`] object
            config: The model configuration
            generation_config: The generation configuration
            input_queue: Queue for incoming requests
            output_queue: Queue for outgoing results
            stop_event: Event to signal processing should stop
            model_device: Device for model inputs/outputs
            model_dtype: Data type for model inputs/outputs
            scheduler: The [`Scheduler`] to use
            streaming: Whether to stream tokens as they're generated
            manual_eviction: Whether to manually evict blocks from the cache
            slice_inputs: Whether to slice the inputs to the model
        """
        self.cache = cache
        self.config = config
        self.generation_config = generation_config
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.stop_event = stop_event
        self.model_device = model_device
        self.model_dtype = model_dtype
        self.scheduler = scheduler
        self.streaming = streaming
        self.manual_eviction = manual_eviction
        self.slice_inputs = slice_inputs

        # Retrieve the size of the sliding window if there is one
        self.sliding_window = 1 if getattr(config, "sliding_window", None) is None else config.sliding_window

        self.requests_in_batch: list[RequestState] = []

        # Set up the metrics collector
        self.max_batch_tokens = cache.max_batch_tokens
        self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)

        # Set up static tensors
        self.total_query_length = 0
        self.total_key_length = 0
        self.total_batch_size = 0
        self.setup_static_tensors(cache.num_groups)

    @traced(standalone=True)
    def setup_static_tensors(self, num_groups: int) -> None:
        T = self.max_batch_tokens
        num_pages = self.cache.num_blocks * self.cache.block_size
        self.tensor_metadata = {"dtype": torch.int32, "device": self.model_device}

        # Some tensors always have the same shape regardless of the model
        self.input_ids = torch.empty((1, T), **self.tensor_metadata)
        self.position_ids = torch.empty((1, T), **self.tensor_metadata)
        self.cumulative_seqlens_q = torch.empty((T + 1,), **self.tensor_metadata)
        self.max_seqlen_q = 0
        self.logits_indices = torch.empty((T,), **self.tensor_metadata)
        self.output_ids = torch.empty((1, T), **self.tensor_metadata)

        # For some kwargs, we have a dict of tensors with as many items as there are attention types
        layer_types = getattr(self.config, "layer_types", None)
        if layer_types is None:
            sliding_window = getattr(self.config, "sliding_window", 1)
            layer_types = ["full_attention"] if sliding_window in [1, None] else ["sliding_attention"]
        layer_types = list(set(layer_types))

        self.cumulative_seqlens_k = {
            layer_type: torch.empty((T + 1), **self.tensor_metadata) for layer_type in layer_types
        }
        self.max_seqlen_k = dict.fromkeys(layer_types, 0)

        if self.return_attention_mask():
            attn_mask_kwargs = {
                "size": (1, 1, T, num_pages + T),
                "dtype": self.model_dtype,
                "device": self.model_device,
            }
            self.attention_mask = {layer_type: torch.empty(**attn_mask_kwargs) for layer_type in layer_types}
        else:
            self.attention_mask = None

        # For other kwargs, we need a list of tensors with as many tensors as there are groups
        self.write_index_storage = [torch.empty((T,), **self.tensor_metadata) for _ in range(num_groups)]
        self.read_index_storage = [torch.empty((num_pages + T), **self.tensor_metadata) for _ in range(num_groups)]
        # For the read index, the +T is because there are -1 entries for seqlen_q when the model uses a sliding window

        # After allocating empty tensors, we reset them to the right value
        self.reset_static_tensors(full_reset=True)

    def return_attention_mask(self) -> bool:
        return self.config._attn_implementation != "paged_attention"  # we set `is_causal` to True in the paged call

    @traced
    @torch.no_grad()
    def reset_static_tensors(self, full_reset: bool = False):
        """Reset static tensors for the next batch. In between batches, reset only the parts that were used in the
        last batch; for initialisation, we can reset everything using the `full_reset` flag."""
        # Compute the slice to reset
        if full_reset or not self.slice_inputs:
            q_len = self.write_index_storage[0].size(-1)
            k_len = self.read_index_storage[0].size(-1)
            b_size = self.write_index_storage[0].size(0)
        else:
            q_len = self.total_query_length
            k_len = self.total_key_length
            b_size = self.total_batch_size

        # Reset the attributes that always have the same shape
        self.input_ids[:, :q_len].zero_()
        self.position_ids[:, :q_len].zero_()
        self.cumulative_seqlens_q[: b_size + 1].zero_()
        self.max_seqlen_q = 0
        self.logits_indices[:q_len].fill_(-1)
        self.output_ids[:, :q_len].fill_(-1)

        # Reset the attributes that are either tensors or dicts of tensors
        for layer_type in self.cumulative_seqlens_k:
            self.cumulative_seqlens_k[layer_type][: b_size + 1].zero_()
            self.max_seqlen_k[layer_type] = 0
            if self.attention_mask is not None:
                self.attention_mask[layer_type][:, :, :q_len, :k_len].fill_(torch.finfo(self.model_dtype).min)

        # Reset the attributes that are lists of tensors
        for i in range(self.cache.num_groups):
            self.write_index_storage[i][:q_len].fill_(-1)
            self.read_index_storage[i][: q_len + k_len].fill_(-1)

    def get_model_kwargs(self) -> PagedAttentionArgs:
        """Get model keyword arguments for the current batch."""
        # Compute the slice to return
        q_len = self.total_query_length if self.slice_inputs else self.write_index_storage[0].size(-1)
        b_size = self.total_batch_size if self.slice_inputs else self.cumulative_seqlens_q.size(-1) - 1

        # Prepare the kwargs; the attributes that are either tensors or dicts of tensors start out as empty dicts
        kwargs = {
            "input_ids": self.input_ids[:, :q_len],
            "position_ids": self.position_ids[:, :q_len],
            "cu_seq_lens_q": self.cumulative_seqlens_q[: b_size + 1],
            "max_seqlen_q": self.max_seqlen_q,
            "logits_indices": self.logits_indices[:q_len],
            "cu_seq_lens_k": {},
            "max_seqlen_k": {},
            "attention_mask": {},
            "read_index": self.read_index,  # slicing is done during building
            "write_index": self.write_index,  # slicing is done during building
            "cache": self.cache,
            "use_cache": False,
        }

        # For the attributes that are dicts of tensors, we replace the dict with a tensor if there is only one entry
        layer_types = list(self.cumulative_seqlens_k.keys())
        if len(layer_types) > 1:
            for layer_type, seqlens_k in self.cumulative_seqlens_k.items():
                kwargs["cu_seq_lens_k"][layer_type] = seqlens_k[: b_size + 1]
                kwargs["max_seqlen_k"][layer_type] = self.max_seqlen_k[layer_type]
                if self.attention_mask is not None:
                    k_len = seqlens_k[b_size] if self.slice_inputs else self.attention_mask[layer_type].size(-1)
                    kwargs["attention_mask"][layer_type] = self.attention_mask[layer_type][..., :q_len, :k_len]
        else:
            layer_type = layer_types[0]
            kwargs["cu_seq_lens_k"] = self.cumulative_seqlens_k[layer_type][: b_size + 1]
            kwargs["max_seqlen_k"] = self.max_seqlen_k[layer_type]
            if self.attention_mask is not None:
                k_len = self.cumulative_seqlens_k[layer_type][b_size]
                k_len = k_len if self.slice_inputs else self.attention_mask[layer_type].size(-1)
                kwargs["attention_mask"] = self.attention_mask[layer_type][..., :q_len, :k_len]

        if self.attention_mask is None:
            kwargs["attention_mask"] = None
        return kwargs

    def __repr__(self):
        return (
            f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, "
            f"active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})"
            + self.get_model_kwargs().__repr__()
        )

    @traced
    def _get_new_requests(self):
        """Pull new requests from the input queue and add them to the waiting list."""
        while not self.input_queue.empty():
            try:
                state = self.input_queue.get_nowait()
                if state is None:  # Sentinel value
                    continue
                self.scheduler.add_waiting_request(state)

            except queue.Empty:
                break
            except Exception as e:
                logger.error(f"Error processing new request: {e}", exc_info=True)
                state: RequestState = locals().get("state")
                if state is not None:
                    self._handle_request_error(e, state)

    @traced
    def _handle_request_error(self, error, state: RequestState):
        """Handle a general request processing error."""
        state.status = RequestStatus.FAILED
        state.error = str(error)

        # Include any generated tokens if this is an active request
        if isinstance(state.request_id, str):
            state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id)
        else:
            state.static_outputs = []

        self.metrics.record_request_completion(state.created_time, state.request_id)
        self.output_queue.put(state.to_generation_output())

    @traced
    def prepare_next_batch(self) -> bool:
        """Prepare tensors and metadata for the next model forward pass. Returns True if there are requests to
        process, False otherwise."""

        # Get new requests from the queue, stop if there are no pending requests
        self._get_new_requests()
        self.scheduler.clear_cancelled_requests()
        if not self.scheduler.has_pending_requests():
            return False
        self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests))

        # Schedule the next batch of requests, stop if there are no requests in the batch
        self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens)
        if not self.requests_in_batch:
            return False
        self.metrics.record_batch_metrics(self.requests_in_batch)

        # Reset the static tensors used for storage
        self.reset_static_tensors()  # TODO: with slice_inputs, this might be unnecessary

        # Prepare accumulators
        self.total_query_length = 0
        self.total_key_length = 0
        self.total_batch_size = 0

        input_ids = []
        position_ids = []
        cumulative_seqlens_q = [0]
        logits_indices = []

        if isinstance(self.cumulative_seqlens_k, dict):
            cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
        else:
            cumulative_seqlens_k = [0]

        read_index = [[] for _ in range(self.cache.num_groups)]
        write_index = [[] for _ in range(self.cache.num_groups)]

        # Go through all the requests in the batch
        for state in self.requests_in_batch:
            # First we retrieve the lengths related to the request
            past_length = state.position_offset
            query_length = len(state.prompt_ids)
            seqlens_k = self.cache.get_seqlens_k(state.request_id, past_length, query_length)

            # Then we update the total lengths that are used for slicing
            self.total_query_length += query_length
            # total_key_length is used to slice the keys so we need to take the max of all the key lengths
            self.total_key_length += max(seqlens_k.values())
            self.total_batch_size += 1
            # And the attribute tracking the position in the request object
            state.position_offset += query_length

            # Then we accumulate into the objects used in the kwargs
            input_ids.extend(state.prompt_ids)
            position_ids.extend(range(past_length, past_length + query_length))
            cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length)
            self.max_seqlen_q = max(self.max_seqlen_q, query_length)

            if not state.remaining_prompt_ids:
                logits_indices.append(cumulative_seqlens_q[-1] - 1)

            for layer_type, layer_type_seqlen_k in seqlens_k.items():
                cumulative_seqlens_k[layer_type].append(cumulative_seqlens_k[layer_type][-1] + layer_type_seqlen_k)
                self.max_seqlen_k[layer_type] = max(self.max_seqlen_k[layer_type], layer_type_seqlen_k)

            self.cache.extend_read_indices(state.request_id, past_length, query_length, read_index)
            self.cache.extend_write_indices(state.request_id, past_length, query_length, write_index)

        # When the loop over requests is done, we can build the actual tensors
        self._build_tensors(
            input_ids,
            position_ids,
            read_index,
            write_index,
            cumulative_seqlens_q,
            cumulative_seqlens_k,
            logits_indices,
        )
        self.metrics.record_kv_cache_memory_metrics(self.cache)

        if logger.isEnabledFor(logging.DEBUG):
            if isinstance(self.cumulative_seqlens_k, dict):
                ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
            else:
                ck = cumulative_seqlens_k[-1]
            logger.debug(
                f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
                f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
                f"cum KV: {ck}, free blocks: {self.cache.get_num_free_blocks()}"
            )
        return True

    @traced
    def _build_tensors(
        self,
        input_ids: list[int],
        position_ids: list[int],
        read_index: list[list[int]],
        write_index: list[list[int]],
        cumulative_seqlens_q: list[int],
        cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
        logits_indices: list[int],
    ) -> None:
        """Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
        to_tensor = partial(torch.tensor, **self.tensor_metadata)

        # These kwargs always have the same type regardless of the model
        self.input_ids[:, : len(input_ids)] = to_tensor(input_ids)
        self.position_ids[:, : len(position_ids)] = to_tensor(position_ids)
        self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q)
        self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)

        # These kwargs are either dicts of tensors or tensors, so we need to handle both cases
        for layer_type, layer_type_seqlens_k in cumulative_seqlens_k.items():
            self.cumulative_seqlens_k[layer_type][: len(layer_type_seqlens_k)] = to_tensor(layer_type_seqlens_k)
            if self.attention_mask is not None:
                build_attention_mask(
                    attention_mask=self.attention_mask[layer_type],
                    cumulative_seqlens_q=cumulative_seqlens_q,
                    cumulative_seqlens_k=layer_type_seqlens_k,
                    sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1,
                )

        # The indices only contain references to the storage tensors, so we update the storage and their references
        self.read_index = []
        self.write_index = []
        for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index):
            # Write into the actual tensors
            self.read_index_storage[i][: len(group_read_indices)] = to_tensor(group_read_indices)
            self.write_index_storage[i][: len(group_write_indices)] = to_tensor(group_write_indices)
            # Slice to the right size
            r = len(group_read_indices) if self.slice_inputs else self.read_index_storage[i].size(-1)
            w = len(group_write_indices) if self.slice_inputs else self.write_index_storage[i].size(-1)
            # Add to the index
            self.read_index.append(self.read_index_storage[i][:r])
            self.write_index.append(self.write_index_storage[i][:w])

    @traced
    def _sync(self):
        if self.output_ids is not None:
            try:
                out = self.output_ids.tolist()[0]  # should be the only sync we do
            except Exception:
                out = [0, 1]
        else:
            out = [0, 0]
        return out

    @traced
    def _maybe_send_output(self, state: RequestState, token: int):
        """Send output to the queue based on streaming mode and request state."""
        if self.streaming:
            self.output_queue.put(state.to_generation_output())
        elif state.status == RequestStatus.FINISHED:
            self.output_queue.put(state.to_generation_output())

    @traced
    def update_batch(self):
        """Update request states based on generated tokens."""
        out_tokens = self._sync()
        finished_request_ids = []
        for i, state in enumerate(self.requests_in_batch):
            req_id = state.request_id
            if len(state.remaining_prompt_ids) == 0:
                self.metrics.record_ttft_metric(state.created_time, state.request_id)
                state.status = RequestStatus.DECODING
                token = out_tokens[self.logits_indices[i]]
                state.prompt_ids = [token]
                if state.update_with_token(token):
                    self.metrics.record_request_completion(state.created_time, state.request_id)
                    self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
                    finished_request_ids.append(req_id)
                self._maybe_send_output(state, token)
            elif state.status == RequestStatus.PREFILLING_SPLIT:
                state.status = RequestStatus.SPLIT_PENDING_REMAINDER
        if self.cache.get_num_free_blocks() == 0:
            raise ValueError("No more free blocks")

    @traced
    def has_pending_requests(self) -> bool:
        """Check if there are any active or waiting requests."""
        return self.scheduler.has_pending_requests()

    @traced
    def handle_batch_error(self, error):
        """Handle errors during batch processing."""
        failed_reqs = self.requests_in_batch
        for req in failed_reqs:
            self._handle_request_error(error, req)
            self.scheduler.finish_request(req.request_id)

    @traced
    def fail_all_requests(self, error):
        """Fail all active requests with the given error.

        Args:
            error: The error to report in the failure message
        """

        requests = list(self.scheduler.active_requests.values())
        for state in requests:
            self._handle_request_error(error, state)
            self.scheduler.finish_request(state.request_id)

        # Also fail any requests in the waiting queue
        for req_id in list(self.scheduler.waiting_requests.keys()):
            state = self.scheduler.waiting_requests.pop(req_id)
            self._handle_request_error(error, state)

        # Clear the ordering queue
        self.scheduler.waiting_requests_order.clear()

# Manager Class (User Interface)
@attach_tracer()
class ContinuousBatchingManager:
    """Manager for handling continuous batching of generation requests.

    This class provides the user interface for submitting generation requests,
    retrieving results, and managing the background generation thread.
    """

    def __init__(
        self,
        model,
        generation_config: GenerationConfig,
        manual_eviction: bool = False,
        max_queue_size=0,
        streaming: bool = True,
        slice_inputs: bool = True,
    ):
        """Initialize the continuous batching manager.

        Args:
            model: The language model for generation
            generation_config: Configuration for generation parameters
            manual_eviction: Whether cache blocks must be evicted explicitly by the caller
            max_queue_size: Maximum size of the request queue (0 = unlimited)
            streaming: Whether to stream tokens as they are generated
            slice_inputs: Whether to slice the static tensors down to the actual batch size
        """
        self.model = model.eval()
        generation_config = model.generation_config if generation_config is None else generation_config
        self.generation_config = generation_config
        self.input_queue = queue.Queue(maxsize=max_queue_size)
        self.output_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.streaming = streaming
        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
        self._generation_thread = None
        self._request_counter = 0
        self._request_lock = threading.Lock()
        self.model.generation_config.top_p = None
        self.do_sample = getattr(generation_config, "do_sample", True)
        self.logit_processor = self.model._get_logits_processor(generation_config)
        self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", False)  # TODO: same as do_sample
        self.profile = getattr(generation_config, "profile", False)
        self.manual_eviction = manual_eviction
        self.batch_processor: Optional[ContinuousBatchProcessor] = None
        self.slice_inputs = slice_inputs

        if self.use_cuda_graph:
            raise NotImplementedError("CUDA graphs are not supported yet")

    @traced
    def start(self):
        """Start the background generation thread."""
        if self._generation_thread is not None and self._generation_thread.is_alive():
            logger.warning("Manager thread is already running.")
            return

        self._result_queue = queue.Queue()
        self._generation_thread = threading.Thread(target=self._run_generation_loop)
        self._generation_thread.start()

    def is_running(self):
        """Check if the background generation thread is running."""
        return self._generation_thread is not None and self._generation_thread.is_alive()

    def stop(self, block: bool = False, timeout: Optional[float] = None):
        """Signal the background thread to stop.

        Args:
            block: Whether to wait for the thread to stop
            timeout: Maximum time to wait for the thread to stop
        """
        if self._generation_thread is None:
            logger.warning("Manager not started.")
            return

        if not self.stop_event.is_set():
            self.stop_event.set()
            logger.info("Stopping continuous batching manager...")

        if block:
            self.join(timeout)

    def join(self, timeout: Optional[float] = None):
        """Wait for the background thread to finish.

        Args:
            timeout: Maximum time to wait for the thread to stop
        """
        if self._generation_thread is not None:
            self._generation_thread.join(timeout=timeout)
            if self._generation_thread.is_alive():
                logger.warning("Generation thread did not exit after join timeout.")
            else:
                logger.info("Continuous Batching Manager stopped.")
                self._generation_thread = None

    def add_request(
        self, input_ids: list[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None
    ) -> str:
        """Add a new generation request to the queue.

        Args:
            input_ids: Input token IDs to use as prompt
            request_id: Optional custom request ID (auto-generated if None)
            max_new_tokens: Optional cap on the number of generated tokens (defaults to the generation config)

        Returns:
            str: The request ID
        """
        if request_id is None:
            with self._request_lock:
                request_id = f"req_{self._request_counter}"
                self._request_counter += 1

        max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens

        # NOTE: do we want to handle a case when the user wants token ids returned instead of decoded text?
        state = RequestState(
            request_id=request_id,
            prompt_ids=list(input_ids),
            full_prompt_ids=list(input_ids),
            max_new_tokens=max_new_tokens,
            eos_token_id=self.generation_config.eos_token_id,
        )

        # Use block=True with timeout to handle backpressure if the queue is full
        self.input_queue.put(state, block=True, timeout=10)  # XXX: pass timeout as fn arg?
        logger.debug(f"Added request {request_id} to queue.")
        return request_id

    def add_requests(self, inputs: list[list[int]], **kwargs):
        for input_ids in inputs:
            self.add_request(input_ids, **kwargs)

    def cancel_request(self, request_id: str):
        """Cancel a request by its ID.

        Args:
            request_id: The ID of the request to cancel
        """
        if self.batch_processor is not None:
            self.batch_processor.scheduler.set_request_cancellation(request_id)

    def get_result(self, request_id=None, timeout=None) -> Optional[GenerationOutput]:
        """Retrieve one result from the output queue.

        Args:
            request_id: If set, only return a result for this request id (other results are re-queued)
            timeout: Maximum time to wait for a result

        Returns:
            Optional[GenerationOutput]: The result data or None if timeout
        """
        if self._generation_thread is None and self.output_queue.empty():
            return None
        try:
            result = self.output_queue.get(block=True, timeout=timeout)
            if request_id is not None and result.request_id != request_id:
                self.output_queue.put(result)
                return None
            logger.debug(f"Retrieved result for request {result.request_id}")
            return result
        except queue.Empty:
            return None

    def __iter__(self):
        """Iterate over results as they become available."""
        while self._generation_thread is not None and self._generation_thread.is_alive():
            result = self.get_result(timeout=0.1)
            if result is not None:
                yield result

    def request_id_iter(self, request_id):
        """Iterate over results matching a specific request id as they become available."""
        request_cancelled = False
        while self._generation_thread is not None and self._generation_thread.is_alive() and not request_cancelled:
            result = self.get_result(request_id=request_id, timeout=0.1)
            if result is not None:
                yield result
            if self.batch_processor is not None:
                request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

    @staticmethod
    def supported_attention_implementations() -> set[str]:
        return {"eager_paged", "sdpa_paged", "flash_attention_2"}

    @staticmethod
    def default_attention_implementation() -> str:
        return "sdpa_paged"

    @traced
    def warmup(self, batch_processor):
        stream = torch.cuda.Stream(device=self.model.device)
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            # Warm up the model with a dummy forward pass
            self._generation_step(batch_processor)
        torch.cuda.current_stream().wait_stream(stream)

        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, stream=stream):
            self._generation_step(batch_processor)

    @traced
    # @torch.compile
    def _generation_step(self, batch_processor: ContinuousBatchProcessor):
        """Perform a single generation step. This is CUDA-graphed when enabled."""
        batch_data = batch_processor.get_model_kwargs()
        with torch.no_grad():
            logits = self._model_forward(batch_data)
            if self.log_prob_generation:
                batch_processor.output_probs.copy_(logits)  # TODO
            probs = self._process_logit(batch_data, logits)
            self._sample(batch_processor, probs)

    @traced(span_name="model_forward")
    def _model_forward(self, batch_data):
        return self.model(**batch_data).logits

    @traced(span_name="logit_processing")
    def _process_logit(self, batch_data, logits):
        # Pass the continuous batching context to the logits processor if it supports it.
        # TODO: we should find a way to make this a little cleaner!
        if hasattr(self.logit_processor, "set_continuous_batching_context"):
            self.logit_processor.set_continuous_batching_context(
                batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
            )

        # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size],
        # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
        batch_size, seq_len, vocab_size = logits.shape
        logits_2d = logits.view(batch_size * seq_len, vocab_size)
        input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)

        # Process with 2D tensors
        processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d)

        # Reshape back to 3D
        return processed_logits_2d.view(batch_size, seq_len, vocab_size)

    @traced(span_name="sampling")
    def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
        if self.do_sample:  # sample
            probs = nn.functional.softmax(probs, dim=-1)
            # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # Now [seq_len]
            # Add the batch dimension back to match the argmax output
            next_tokens = next_tokens.unsqueeze(0)  # Now [1, seq_len]
        else:
            next_tokens = torch.argmax(probs, dim=-1)  # Already [1, seq_len]

        tokens = next_tokens.size(1)  # Get the seq_len dimension
        batch_processor.output_ids[:, :tokens].copy_(next_tokens)

    def _run_generation_loop(self):
        """Main processing loop running in the background thread."""
        batch_processor = None
        try:
            ref_time = perf_counter()
            paged_attention_cache = PagedAttentionCache(
                self.model.config,
                self.generation_config,
                self.model.device,
                self.model.dtype,
                tp_size=getattr(self.model, "_tp_size", None),  # Use the model's actual TP setting
            )
            logger.debug(f"PagedAttentionCache created in {perf_counter() - ref_time} seconds")

            scheduler = None
            if hasattr(self.generation_config, "scheduler"):
                scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler, None)
                if scheduler is None:
                    logger.warning(f"Scheduler '{self.generation_config.scheduler}' not found. Defaulting to FIFO.")
                    scheduler = FIFOScheduler
            else:
                # Default to FIFO
                scheduler = FIFOScheduler

            ref_time = perf_counter()
            batch_processor = ContinuousBatchProcessor(
                paged_attention_cache,
                self.model.config,
                self.generation_config,
                self.input_queue,
                self.output_queue,
                self.stop_event,
                self.model.device,
                self.model.dtype,
                scheduler(paged_attention_cache, self.manual_eviction),
                self.streaming,
                self.manual_eviction,
                slice_inputs=self.slice_inputs,
            )
            self.batch_processor = batch_processor
            self.current_batch = 0
            logger.debug(f"batch_processor created in {perf_counter() - ref_time} seconds")
            while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
                self._inner_generation_loop(batch_processor)
                self.current_batch += 1

        except Exception as e:
            logger.error(f"Error in generation loop: {e}", exc_info=True)
            self._handle_critical_error(e, batch_processor)
        finally:
            logger.info("Generation loop finished.")

    @traced(span_name="generation_loop")
    def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if not batch_processor.prepare_next_batch():
            return
        if logger.level <= logging.DEBUG:
            device, total, reserved, allocated = get_device_and_memory_breakdown()
            logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
        if torch.cuda.is_available() and self.use_cuda_graph:
            if self.current_batch == 0:
                self.warmup(batch_processor)
            elif hasattr(self, "graph"):
                try:
                    self._graph_replay()
                except Exception as e:
                    logger.error(f"Model forward pass failed: {e}", exc_info=True)
                    batch_processor.handle_batch_error(e)
                    return
            else:
                self._generation_step(batch_processor)
        else:
            self._generation_step(batch_processor)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        batch_processor.update_batch()

    @traced(span_name="graph_replay")
    def _graph_replay(self):
        self.graph.replay()

    @traced
    def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]):
        """Handle critical errors that terminate the generation loop."""
        # Signal stop
        self.stop_event.set()

        # Fail pending requests in the input queue
        try:
            while True:
                req_data = self.input_queue.get_nowait()
                if batch_processor is not None:
                    batch_processor._handle_request_error(error, req_data)
        except queue.Empty:
            pass

        # Fail active requests
        if batch_processor is not None:
            batch_processor.fail_all_requests(error)

    @traced
    def evict_request_from_cache(self, request_id: str):
        """Evict a request from the cache. It is assumed that the request is already finished."""
        if not self.manual_eviction:
            raise RuntimeError("Manual eviction is not enabled for this manager.")
        if self.batch_processor is not None:
            self.batch_processor.scheduler.finish_request(request_id)

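Putting the manager together, a typical lifecycle looks like the sketch below (illustrative only: `model` stands for any causal LM that inherits `ContinuousMixin`, and the token ids are placeholders):

    manager = model.init_continuous_batching()  # defined on ContinuousMixin below
    manager.start()
    req_id = manager.add_request([101, 2023, 2003, 102], max_new_tokens=32)
    output = manager.get_result(request_id=req_id, timeout=30)  # GenerationOutput, or None on timeout
    manager.stop(block=True, timeout=10)
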
class ContinuousMixin:
    """Mixin class for models to add continuous batching capabilities."""

    def init_continuous_batching(
        self,
        generation_config: Optional[GenerationConfig] = None,
        manual_eviction: bool = False,
        max_queue_size: int = 0,
        streaming: bool = False,
        slice_inputs: bool = True,
    ) -> ContinuousBatchingManager:
        """Initialize a manager for continuous batching inference.

        Args:
            generation_config: Custom generation configuration
            manual_eviction: Whether cache blocks must be evicted explicitly by the caller
            max_queue_size: Maximum size of the input request queue
            streaming: Whether to stream tokens as they are generated
            slice_inputs: Whether to slice the static tensors down to the actual batch size

        Returns:
            `ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
        """
        if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"):
            raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.")

        gen_config = generation_config if generation_config is not None else self.generation_config
        if gen_config is None:
            raise ValueError("A GenerationConfig must be provided or set in the model.")

        if gen_config.eos_token_id is None:
            logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).")
            gen_config.eos_token_id = -1

        # Create and return the manager
        return ContinuousBatchingManager(
            model=self,
            generation_config=gen_config,
            manual_eviction=manual_eviction,
            max_queue_size=max_queue_size,
            streaming=streaming,
            slice_inputs=slice_inputs,
        )

    @traced
    @torch.inference_mode()
    def generate_batch(
        self,
        inputs: list[list[int]],
        generation_config: Optional[GenerationConfig] = None,
        progress_bar: bool = True,
        slice_inputs: bool = True,
        **kwargs,
    ) -> list[list[int]]:
        """Generate sequences for a batch of prompts using continuous batching.

        Args:
            inputs: List of input token sequences (prompts)
            generation_config: Optional generation configuration
            progress_bar: Whether to display a tqdm progress bar
            slice_inputs: Whether to slice the static tensors down to the actual batch size
            **kwargs: Additional generation parameters

        Returns:
            `list[list[int]]`: A list containing the generated sequences (including prompt tokens
            if not handled otherwise) for each input prompt, in the same order.
            Returns an empty list `[]` for requests that failed.
        """
        if not inputs:
            return []
        if logger.getEffectiveLevel() <= logging.DEBUG:
            logger.warning("Progress bar is disabled when logger level is less than DEBUG")
            progress_bar = False

        # Initialize the manager with the batch inputs
        manager = self.init_continuous_batching(generation_config=generation_config, slice_inputs=slice_inputs)
        manager.start()
        results = {}
        num_requests = len(inputs)
        try:
            from tqdm.contrib.logging import logging_redirect_tqdm

            with logging_redirect_tqdm([logger]):
                with tqdm(
                    total=num_requests,
                    disable=(not progress_bar),
                    desc=f"Solving {num_requests} requests",
                    unit="request",
                ) as pbar:
                    manager.add_requests(inputs, **kwargs)
                    finished_count = 0
                    while finished_count < num_requests:
                        result = manager.get_result(timeout=1)
                        if result:
                            req_id = result.request_id
                            if result.status == RequestStatus.FINISHED:
                                results[req_id] = result
                                finished_count += 1
                                pbar.update(1)
                        else:
                            if not manager.is_running():
                                logger.error("Generation thread terminated unexpectedly.")
|
| 1041 |
+
break
|
| 1042 |
+
|
| 1043 |
+
except Exception as e:
|
| 1044 |
+
logger.error(f"Error during batch generation: {e}", exc_info=True)
|
| 1045 |
+
finally:
|
| 1046 |
+
manager.stop(block=True, timeout=5.0)
|
| 1047 |
+
return results
|
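For reference, a minimal usage sketch of the mixin above. The checkpoint name, the tokenizer handling, and the assumption that per-request kwargs such as `max_new_tokens` flow through `add_requests` are all illustrative, not guaranteed by this file:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical checkpoint choice; any causal LM exposing `config`, `device`, `dtype` works the same way.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")

    prompts = [tok.encode(p) for p in ["Hello, my name is", "The capital of France is"]]

    # One-shot path: blocks until every request is FINISHED and returns a dict keyed by request id.
    results = model.generate_batch(inputs=prompts, max_new_tokens=32)
    for req_id, output in results.items():
        print(req_id, tok.decode(output.generated_tokens))

    # Lower-level path: drive the manager directly and fetch results as they finish.
    manager = model.init_continuous_batching()
    manager.start()
    manager.add_requests(prompts, max_new_tokens=32)
    try:
        first = manager.get_result(timeout=5)  # a GenerationOutput, or None on timeout
    finally:
        manager.stop(block=True, timeout=5.0)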
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/requests.py
ADDED
@@ -0,0 +1,204 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

import torch

from ...utils.logging import logging
from ...utils.metrics import traced


# We centralize the logger here to coordinate between logging and progress bar
logger = logging.getLogger("ContinuousBatchingLogger")
# logger.setLevel(logging.INFO)


def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # MPS memory reporting (PyTorch 2.0+)
        total_memory = torch.mps.driver_allocated_memory()
        allocated_memory = total_memory - torch.mps.recommended_max_memory()
        reserved_memory = 0  # MPS does not track reserved separately
    else:
        device = torch.device("cpu")
        total_memory = None
        reserved_memory = 0
        allocated_memory = 0
    return device, total_memory, reserved_memory, allocated_memory

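# Usage sketch (illustrative, not part of the module API): the helper above is what the
# generation loop logs at DEBUG level. On CUDA all three sizes are in bytes; on CPU,
# total_memory is None and the other two are reported as 0.
#
#     device, total, reserved, allocated = get_device_and_memory_breakdown()
#     logger.debug(f"{device}: total={total} reserved={reserved} allocated={allocated}")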

class RequestStatus(Enum):
    """Status of a generation request through its lifecycle."""

    PENDING = "pending"
    PREFILLING = "prefilling"
    PREFILLING_SPLIT = "prefilling_split"
    SPLIT_PENDING_REMAINDER = "split_pending_remainder"
    DECODING = "decoding"
    FINISHED = "finished"
    FAILED = "failed"


@dataclass
class GenerationOutput:
    """Tracks the output of a generation request.

    Attributes:
        request_id (str): The ID of the generation request.
        prompt_ids (list[int]): The IDs of the prompt tokens.
        generated_tokens (list[int]): The generated tokens.
        logprobs (list[float]): The log probabilities of the generated tokens.
        error (Optional[str]): Any error message associated with the request. When None, the request was successful.
        status (RequestStatus): The status of the request.
        created_time (float): The time the request was created.
    """

    request_id: str
    prompt_ids: list[int] = field(default_factory=list)
    generated_tokens: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    error: Optional[str] = None
    status: RequestStatus = RequestStatus.PENDING
    created_time: float = field(default_factory=time.time)


@dataclass
class RequestState:
    """Tracks the state of a generation request through its lifecycle.

    Attributes:
        request_id (str): The ID of the generation request.
        full_prompt_ids (list[int] | None): The token IDs of the full prompt.
        prompt_ids (list[int] | None): The token IDs currently being processed.
        remaining_prompt_ids (list[int]): The token IDs remaining to be processed (for split requests).
        static_outputs (list[int]): The generated tokens.
        allocated_blocks (int): The number of blocks allocated to the request.
        position_offset (int): The current position in the sequence for position_ids.
        status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
            SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_token_id (int): The ID of the end-of-sequence token.
        created_time (float): The time the request was created.
        error (Optional[str]): Any error message associated with the request. When None, the request has had no error yet.
    """

    # Required fields
    request_id: str
    full_prompt_ids: Optional[list[int]] = None  # Full initial prompt
    prompt_ids: Optional[list[int]] = None  # Token IDs currently being processed (initial + generated)
    remaining_prompt_ids: list[int] = field(default_factory=list)  # For split requests, prefill left to process
    static_outputs: list[int] = field(default_factory=list)  # Generated tokens
    allocated_blocks: int = 0  # Number of blocks allocated to the request
    position_offset: int = 0  # Current position in the sequence for position_ids
    _status: RequestStatus = RequestStatus.PENDING  # Status of the request, hidden behind a property
    max_new_tokens: int = 20  # Maximum number of new tokens to generate
    eos_token_id: int = -1  # ID of the end-of-sequence token
    created_time: float = field(default_factory=time.time)  # Time the request was created
    error: Optional[str] = None  # Error message if the request failed
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)

    @property
    def status(self) -> RequestStatus:
        return self._status

    @status.setter
    def status(self, value: RequestStatus):
        if self._status == RequestStatus.PENDING:
            self.lifespan = (time.time(), -1)
        elif value == RequestStatus.FINISHED:
            self.lifespan = (self.lifespan[0], time.time())
            self.log_end_of_request()
        self._status = value

    def log_end_of_request(self):
        prefill_len = len(self.full_prompt_ids)
        decode_len = self.generated_len()
        start_time = self.lifespan[0] - self.created_time
        end_time = self.lifespan[1] - self.created_time
        logger.info(
            f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
        )

    def current_len(self) -> int:
        """Get the current length of the sequence (prompt + generated tokens)."""
        return self.position_offset

    def generated_len(self) -> int:
        """Get the number of tokens generated so far."""
        return len(self.static_outputs)

    # TODO: this logic seems one token off, check it out
    @traced
    def update_with_token(self, token_id: int) -> bool:
        """Update the request with a newly generated token and check for completion.

        Args:
            token_id: The token ID to add to the output sequence

        Returns:
            bool: True if the request is now complete, False otherwise
        """
        # Only update if we're in decoding state
        if self.status != RequestStatus.DECODING:
            return False

        is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
        is_max_len = self.generated_len() >= self.max_new_tokens

        # Only add the token if we're not finishing due to max length
        # (EOS tokens should still be added to the output)
        if not (is_max_len and not is_eos):
            self.static_outputs.extend([token_id])

        if is_eos or is_max_len:
            self.status = RequestStatus.FINISHED
            return True
        return False

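    # Worked example of the completion rules above (assumed values, for illustration):
    # with eos_token_id=2 and max_new_tokens=3, a DECODING request whose static_outputs
    # is already [5, 7, 9] has generated_len() >= max_new_tokens, so the next non-EOS
    # token is dropped and the request is marked FINISHED. An EOS token arriving
    # earlier is still appended before finishing, which is why the max-length check
    # carves out the EOS case.
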
    def __repr__(self):
        msg = [
            f"request_id={self.request_id}",
            f"status={self._status}",
            f"out_tokens={self.generated_len()}",
            f"query_length={len(self.prompt_ids)}",
            f"remaining_tokens={len(self.remaining_prompt_ids)}",
            f"kv_length={self.position_offset}",
            f"full_prompt_length={len(self.full_prompt_ids)}",
            f"allocated_blocks={self.allocated_blocks}",
            f"generated_tokens={self.static_outputs}",
        ]
        return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"

    def to_generation_output(self):
        """Convert the request state to a GenerationOutput object."""
        return GenerationOutput(
            request_id=self.request_id,
            prompt_ids=self.full_prompt_ids,
            status=self.status,
            generated_tokens=self.static_outputs,
            logprobs=[],
            error=self.error,
        )
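Taken together, a request state can be exercised in isolation. A minimal sketch, assuming only the classes defined in this file at the module path shown above:

    from transformers.generation.continuous_batching.requests import (
        GenerationOutput,
        RequestState,
        RequestStatus,
    )

    state = RequestState(request_id="req-0", full_prompt_ids=[10, 11], prompt_ids=[10, 11], max_new_tokens=2)
    state.status = RequestStatus.PREFILLING  # leaving PENDING stamps lifespan[0]
    state.status = RequestStatus.DECODING
    state.update_with_token(42)  # appended; returns False
    state.update_with_token(43)  # appended; the *next* token would trigger is_max_len
    out = state.to_generation_output()
    assert isinstance(out, GenerationOutput) and out.generated_tokens == [42, 43]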
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/generation/continuous_batching/scheduler.py
ADDED
@@ -0,0 +1,300 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from abc import ABC, abstractmethod
from collections import deque

from ...utils.metrics import attach_tracer, traced
from .cache import PagedAttentionCache
from .requests import RequestState, RequestStatus


class Scheduler(ABC):
    """
    Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of
    requests from when they are added to the waiting queue to when they are scheduled for processing. Different
    schedulers implement different strategies for prioritizing and batching requests.
    """

    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
        self.active_requests: dict[str, RequestState] = {}
        self.waiting_requests: dict[str, RequestState] = {}
        self.waiting_requests_order: deque[str] = deque()
        self.cache = cache
        self.retain_cache_on_finish = retain_cache_on_finish
        self._cancellation_lock = threading.Lock()
        self._requests_to_cancel: set[str] = set()

    @traced
    def add_waiting_request(self, state: RequestState):
        """Adds a request to the waiting list."""
        if self.retain_cache_on_finish and state.request_id in self.active_requests:
            old_state = self.active_requests.pop(state.request_id)
            state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]  # XXX: check for indexing error?
            state.allocated_blocks = old_state.allocated_blocks
            state.position_offset = old_state.position_offset
        self.waiting_requests[state.request_id] = state
        self.waiting_requests_order.append(state.request_id)

    @abstractmethod
    def schedule_batch(self, token_budget: int) -> list[RequestState]:
        """Schedules requests for the next batch based on available token budget. This method selects which requests
        should be processed in the current batch, considering the token budget and the scheduler's prioritization rules.
        The token_budget is the maximum number of tokens that can be processed in this batch."""
        pass

    @traced
    def has_pending_requests(self) -> bool:
        """Checks if there are requests ready to be processed."""
        return len(self.active_requests) or len(self.waiting_requests)

    @traced
    def finish_request(self, request_id: str, evict_from_cache: bool = True):
        """Completes processing of a request and optionally frees its allocated cache blocks. This method is called
        when a request has finished generation or encountered an error.
        """
        if evict_from_cache:
            self.cache.free_blocks(request_id)
        if request_id in self.active_requests:
            del self.active_requests[request_id]

    @traced
    def get_active_request_static_outputs(self, request_id: str) -> list[int]:
        """Gets generated tokens for an active request."""
        if request_id in self.active_requests:
            return self.active_requests[request_id].static_outputs
        return []

    @traced
    def set_request_cancellation(self, request_id: str):
        """Marks a request for cancellation."""
        with self._cancellation_lock:
            self._requests_to_cancel.add(request_id)

    @traced
    def clear_cancelled_requests(self):
        """Remove all cancelled requests from active and waiting queues."""
        with self._cancellation_lock:
            for request_id in self._requests_to_cancel:
                if request_id in self.active_requests:
                    del self.active_requests[request_id]
                if request_id in self.waiting_requests:
                    del self.waiting_requests[request_id]
                if request_id in self.waiting_requests_order:
                    self.waiting_requests_order.remove(request_id)
                self.cache.free_blocks(request_id)
            self._requests_to_cancel = set()

    @traced
    def request_is_cancelled(self, request_id: str) -> bool:
        """Checks if a request has been cancelled or removed."""
        return request_id in self._requests_to_cancel or (
            request_id not in self.active_requests and request_id not in self.waiting_requests
        )

    @traced
    def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
        """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
        accommodate the next tokens. It calculates how many blocks are needed based on the request's current
        cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
        objects. Returns a boolean indicating if the allocation was successful or not.
        """
        # 1. we check that the occupancy is less than the requested length
        # 2. we allocate enough blocks to cover the requested length
        current_len = state.current_len()
        occupancy = state.allocated_blocks * self.cache.block_size - current_len
        if occupancy < len_next_tokens or state.allocated_blocks == 0:
            blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
            allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
            if allocated is None:
                return False
            state.allocated_blocks += allocated
        return True

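    # Worked example of the arithmetic above (assumed numbers, for illustration): with
    # block_size=16, a request with current_len()=30 and allocated_blocks=2 has an
    # occupancy of 2 * 16 - 30 = 2 free slots. Scheduling len_next_tokens=10 then needs
    # blocks_needed = ((10 - 2 + 1) // 16) + 1 = 1 extra block, so on success
    # allocated_blocks becomes 3 (48 slots for the 40 tokens the request will hold).
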
    @traced(span_name="prepare_request")
    def _prepare_request_for_processing(
        self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
    ):
        """Prepares a request for processing in the current batch."""
        request_tokens = (
            state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
        )
        if len(request_tokens) < token_budget:
            # Can process the entire prompt/remainder
            if state.status == RequestStatus.PENDING:
                self.active_requests[state.request_id] = state
                state.status = RequestStatus.PREFILLING
                request_ids_to_remove_from_waiting.add(state.request_id)
            elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
                state.status = RequestStatus.PREFILLING
                state.prompt_ids = state.remaining_prompt_ids
                state.remaining_prompt_ids = []
        else:
            # Need to split the request
            if state.status == RequestStatus.PENDING:
                self.active_requests[state.request_id] = state
                state.status = RequestStatus.PREFILLING_SPLIT
                request_ids_to_remove_from_waiting.add(state.request_id)
            elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
                state.status = RequestStatus.PREFILLING_SPLIT
            state.remaining_prompt_ids = request_tokens[token_budget:]
            state.prompt_ids = request_tokens[:token_budget]


@attach_tracer()
class FIFOScheduler(Scheduler):
    """This scheduler processes requests in the order they arrive, meaning decoding requests have priority over
    prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
    when 80% of the cache is full, new requests will not be scheduled, to prioritize decoding active requests."""

    def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
        """Initializes the FIFO scheduler. The safety margin is the fraction of free blocks under which we stop
        scheduling new prefill requests, so safety_margin = 0.1 means that when less than 10% of blocks are free,
        or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
        """
        super().__init__(cache, retain_cache_on_finish)
        self.safety_margin = safety_margin

    @traced
    def schedule_batch(self, token_budget: int) -> list[RequestState]:
        priority_states: list[RequestState] = []
        second_priority_states: list[RequestState] = []
        scheduled_requests = []

        for state in self.active_requests.values():
            if state.status == RequestStatus.DECODING:
                priority_states.append(state)
            if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]:
                second_priority_states.append(state)

        # Add waiting requests to second priority
        for req_id in self.waiting_requests_order:
            second_priority_states.append(self.waiting_requests[req_id])

        candidates = priority_states + second_priority_states
        request_ids_to_remove_from_waiting = set()
        safety_margins = self.safety_margin * self.cache.num_blocks

        for state in candidates:
            # If we are outside the safety margin, we only accept decoding requests or the first prefill request
            num_free_blocks = self.cache.get_num_free_blocks()
            outside_safety_margin = num_free_blocks < safety_margins
            if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
                break

            self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
            request_len = len(state.prompt_ids)
            if not self._allocate_blocks_if_needed(
                state, len(state.prompt_ids)
            ):  # don't schedule if we can't allocate blocks
                if len(self.cache._free_blocks) == 0:
                    break
                continue

            @traced
            def _add_to_scheduled_requests(state: RequestState):
                scheduled_requests.append(state)

            _add_to_scheduled_requests(state)

            token_budget -= request_len

            @traced
            def _remove_from_waiting_requests(state: RequestState):
                req_id = state.request_id
                if req_id in self.waiting_requests:
                    del self.waiting_requests[req_id]
                    request_ids_to_remove_from_waiting.add(req_id)

            _remove_from_waiting_requests(state)

            if token_budget == 0:
                break

        self.waiting_requests_order = deque(
            [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
        )

        return scheduled_requests


# FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
@attach_tracer()
class PrefillFirstScheduler(Scheduler):
    """Scheduler that prioritizes split prefill requests over decoding requests. This scheduler ensures that split
    prefill requests (which are continuations of partially processed prompts) are completed before processing new
    decoding requests."""

    @traced
    def schedule_batch(self, token_budget: int) -> list[RequestState]:
        priority_states: list[RequestState] = []
        second_priority_states: list[RequestState] = []
        scheduled_requests = []

        for state in self.active_requests.values():
            # XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account
            if state.status in [RequestStatus.PREFILLING_SPLIT, RequestStatus.SPLIT_PENDING_REMAINDER]:
                priority_states.append(state)
            elif state.status == RequestStatus.DECODING:
                second_priority_states.append(state)

        for req_id in self.waiting_requests_order:
            second_priority_states.append(self.waiting_requests[req_id])

        candidates = priority_states + second_priority_states

        request_ids_to_remove_from_waiting = set()

        for state in candidates:
            self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
            request_len = len(state.prompt_ids)
            if not self._allocate_blocks_if_needed(
                state, len(state.prompt_ids)
            ):  # don't schedule if we can't allocate blocks
                if len(self.cache._free_blocks) == 0:
                    break
                continue

            @traced
            def _add_to_scheduled_requests(state: RequestState):
                scheduled_requests.append(state)

            _add_to_scheduled_requests(state)

            token_budget -= request_len

            @traced
            def _remove_from_waiting_requests(state: RequestState):
                req_id = state.request_id
                if req_id in self.waiting_requests:
                    del self.waiting_requests[req_id]
                    request_ids_to_remove_from_waiting.add(req_id)

            _remove_from_waiting_requests(state)

            if token_budget == 0:
                break

        self.waiting_requests_order = deque(
            [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
        )

        return scheduled_requests


SCHEDULER_MAPPING = {
    "fifo": FIFOScheduler,
    "prefill_first": PrefillFirstScheduler,
}
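To make the FIFO policy concrete, here is a sketch that drives `FIFOScheduler` with a stub cache. The stub and its numbers are assumptions for illustration only; the real `PagedAttentionCache` lives in `cache.py`:

    from transformers.generation.continuous_batching.requests import RequestState
    from transformers.generation.continuous_batching.scheduler import FIFOScheduler

    class StubCache:
        # Just enough surface for the scheduler: 10 blocks of 4 slots each.
        block_size, num_blocks = 4, 10

        def __init__(self):
            self._free_blocks = list(range(self.num_blocks))

        def get_num_free_blocks(self):
            return len(self._free_blocks)

        def allocate_blocks(self, n, request_id):
            if n > len(self._free_blocks):
                return None  # triggers the "don't schedule" branch above
            del self._free_blocks[:n]
            return n

        def free_blocks(self, request_id):
            pass

    scheduler = FIFOScheduler(StubCache())
    scheduler.add_waiting_request(RequestState(request_id="a", prompt_ids=[1] * 6, full_prompt_ids=[1] * 6))
    batch = scheduler.schedule_batch(token_budget=8)  # the 6-token prefill fits; 2 blocks get allocated
    print([s.request_id for s in batch])  # ['a'], now PREFILLING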
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/__init__.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_albert import *
    from .modeling_albert import *
    from .modeling_flax_albert import *
    from .modeling_tf_albert import *
    from .tokenization_albert import *
    from .tokenization_albert_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
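The `else` branch above replaces the package module with a `_LazyModule`, so the heavy framework-specific submodules are only imported when first accessed. A minimal sketch of the same idea (independent of the actual `_LazyModule` internals, which are assumptions here):

    import importlib
    import types

    class LazyModule(types.ModuleType):
        """Imports a submodule only when one of its attributes is first requested."""

        def __init__(self, name, attr_to_submodule):
            super().__init__(name)
            # e.g. {"AlbertConfig": ".configuration_albert"} -- hypothetical mapping
            self._attr_to_submodule = attr_to_submodule

        def __getattr__(self, attr):
            submodule = self._attr_to_submodule.get(attr)
            if submodule is None:
                raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
            module = importlib.import_module(submodule, package=self.__name__)
            value = getattr(module, attr)
            setattr(self, attr, value)  # cache so __getattr__ is not hit again
            return value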
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/configuration_albert.py
ADDED
@@ -0,0 +1,170 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ALBERT model configuration"""

from collections import OrderedDict
from collections.abc import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig


class AlbertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`AlbertModel`] or a [`TFAlbertModel`]. It is used
    to instantiate an ALBERT model according to the specified arguments, defining the model architecture. Instantiating
    a configuration with the defaults will yield a similar configuration to that of the ALBERT
    [albert/albert-xxlarge-v2](https://huggingface.co/albert/albert-xxlarge-v2) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30000):
            Vocabulary size of the ALBERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
        embedding_size (`int`, *optional*, defaults to 128):
            Dimensionality of vocabulary embeddings.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_hidden_groups (`int`, *optional*, defaults to 1):
            Number of groups for the hidden layers; parameters in the same group are shared.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 16384):
            The dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        inner_group_num (`int`, *optional*, defaults to 1):
            The number of inner repetitions of attention and ffn.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu_new"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`AlbertModel`] or [`TFAlbertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        classifier_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for attached classifiers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 3):
            End of stream token id.

    Examples:

    ```python
    >>> from transformers import AlbertConfig, AlbertModel

    >>> # Initializing an ALBERT-xxlarge style configuration
    >>> albert_xxlarge_configuration = AlbertConfig()

    >>> # Initializing an ALBERT-base style configuration
    >>> albert_base_configuration = AlbertConfig(
    ...     hidden_size=768,
    ...     num_attention_heads=12,
    ...     intermediate_size=3072,
    ... )

    >>> # Initializing a model (with random weights) from the ALBERT-base style configuration
    >>> model = AlbertModel(albert_base_configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "albert"

    def __init__(
        self,
        vocab_size=30000,
        embedding_size=128,
        hidden_size=4096,
        num_hidden_layers=12,
        num_hidden_groups=1,
        num_attention_heads=64,
        intermediate_size=16384,
        inner_group_num=1,
        hidden_act="gelu_new",
        hidden_dropout_prob=0,
        attention_probs_dropout_prob=0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout_prob=0.1,
        position_embedding_type="absolute",
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_groups = num_hidden_groups
        self.num_attention_heads = num_attention_heads
        self.inner_group_num = inner_group_num
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout_prob = classifier_dropout_prob
        self.position_embedding_type = position_embedding_type


# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert
class AlbertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis),
            ]
        )


__all__ = ["AlbertConfig", "AlbertOnnxConfig"]
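A quick sketch of the ONNX config above (it assumes `OnnxConfig` accepts the model config plus an optional `task` argument, as the `self.task` check suggests):

    from transformers import AlbertConfig
    from transformers.models.albert.configuration_albert import AlbertOnnxConfig

    config = AlbertConfig(hidden_size=768, num_attention_heads=12, intermediate_size=3072)
    onnx_config = AlbertOnnxConfig(config)
    print(onnx_config.inputs)
    # OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
    #              ('attention_mask', {0: 'batch', 1: 'sequence'}),
    #              ('token_type_ids', {0: 'batch', 1: 'sequence'})])

    # For multiple choice, an extra dynamic "choice" axis appears:
    mc = AlbertOnnxConfig(config, task="multiple-choice")
    print(mc.inputs["input_ids"])  # {0: 'batch', 1: 'choice', 2: 'sequence'}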
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py
ADDED
@@ -0,0 +1,1349 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ALBERT model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Optional, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
| 28 |
+
from ...modeling_outputs import (
|
| 29 |
+
BaseModelOutput,
|
| 30 |
+
BaseModelOutputWithPooling,
|
| 31 |
+
MaskedLMOutput,
|
| 32 |
+
MultipleChoiceModelOutput,
|
| 33 |
+
QuestionAnsweringModelOutput,
|
| 34 |
+
SequenceClassifierOutput,
|
| 35 |
+
TokenClassifierOutput,
|
| 36 |
+
)
|
| 37 |
+
from ...modeling_utils import PreTrainedModel
|
| 38 |
+
from ...pytorch_utils import (
|
| 39 |
+
apply_chunking_to_forward,
|
| 40 |
+
find_pruneable_heads_and_indices,
|
| 41 |
+
prune_linear_layer,
|
| 42 |
+
)
|
| 43 |
+
from ...utils import ModelOutput, auto_docstring, logging
|
| 44 |
+
from .configuration_albert import AlbertConfig
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+

def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load TF checkpoints in a PyTorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model
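
# A minimal conversion sketch (illustrative only; the config and checkpoint paths
# below are hypothetical placeholders, not files shipped with the library):
#
#     config = AlbertConfig.from_json_file("albert_config.json")
#     model = AlbertForPreTraining(config)
#     load_tf_weights_in_albert(model, config, "/path/to/albert_model.ckpt")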


class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config: AlbertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Default token_type_ids to the all-zero registered buffer from the constructor, which is what
        # they would be when auto-generated. The buffer lets users trace the model without passing
        # token_type_ids and solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
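
# Shape sketch (illustrative; assumes a default AlbertConfig, whose embedding_size is 128):
#
#     config = AlbertConfig()
#     emb = AlbertEmbeddings(config)
#     out = emb(input_ids=torch.ones((2, 10), dtype=torch.long))
#     # out.shape == (2, 10, config.embedding_size), i.e. (2, 10, 128)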


class AlbertAttention(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def prune_heads(self, heads: list[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        value_layer = self.value(hidden_states)
        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
            1, 2
        )
        key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
            1, 2
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the AlbertModel forward() function)
            attention_scores = attention_scores + attention_mask

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(2, 1).flatten(2)

        projected_context_layer = self.dense(context_layer)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
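
# Head-pruning sketch (illustrative; prunes two heads from a freshly constructed layer):
#
#     attn = AlbertAttention(AlbertConfig())
#     attn.prune_heads([0, 1])
#     # attn.num_attention_heads is reduced by 2 and the query/key/value/dense
#     # projections are shrunk accordingly via prune_linear_layer.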


class AlbertSdpaAttention(AlbertAttention):
    def __init__(self, config):
        super().__init__(config)
        self.dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
        if self.position_embedding_type != "absolute" or output_attentions:
            logger.warning(
                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True`. Falling back to "
                "the eager attention implementation, but specifying the eager implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)

        batch_size, seq_len, _ = hidden_states.size()
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        attention_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_layer,
            key=key_layer,
            value=value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )

        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)

        projected_context_layer = self.dense(attention_output)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer,)


ALBERT_ATTENTION_CLASSES = {
    "eager": AlbertAttention,
    "sdpa": AlbertSdpaAttention,
}


class AlbertLayer(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

        return (hidden_states,) + attention_output[1:]  # add attentions if we output them

    def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output


class AlbertLayerGroup(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[BaseModelOutput, tuple]:
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
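
# Worked example of the layer-to-group mapping above (illustrative): with
# num_hidden_layers=12 and num_hidden_groups=1 (the usual ALBERT setting),
# layers_per_group == 12 and every step i maps to group_idx == 0, so a single
# parameter-shared AlbertLayerGroup is applied 12 times. With 3 groups instead:
#
#     layers_per_group = 12 // 3               # 4
#     group_idx = [i // 4 for i in range(12)]  # [0,0,0,0, 1,1,1,1, 2,2,2,2]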


@auto_docstring
class AlbertPreTrainedModel(PreTrainedModel):
    config: AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, AlbertMLMHead):
            module.bias.data.zero_()


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`AlbertForPreTraining`].
    """
)
class AlbertForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    sop_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@auto_docstring
class AlbertModel(AlbertPreTrainedModel):
    config: AlbertConfig
    base_model_prefix = "albert"

    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT
        has a different architecture in that its layers are shared across groups, which then have inner groups. If an
        ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different
        layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
        while [2,3] correspond to the two inner groups of the second hidden layer.

        Any layer with an index other than [0,1,2,3] will result in an error. See base class PreTrainedModel for more
        information about head pruning
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPooling, tuple]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        use_sdpa_attention_mask = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if use_sdpa_attention_mask:
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, embedding_output.dtype, tgt_len=seq_length
            )
        else:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
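
# Minimal encoding sketch (illustrative; the checkpoint name follows the
# docstring examples used elsewhere in this file):
#
#     tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
#     model = AlbertModel.from_pretrained("albert/albert-base-v2")
#     outputs = model(**tokenizer("Hello, my dog is cute", return_tensors="pt"))
#     # outputs.last_hidden_state: (batch, seq_len, hidden_size)
#     # outputs.pooler_output: (batch, hidden_size)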


@auto_docstring(
    custom_intro="""
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        return self.albert.embeddings.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        sentence_order_label: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[AlbertForPreTrainingOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
            sequence B), `1` indicates switched order (sequence B, then sequence A).

        Example:

        ```python
        >>> from transformers import AutoTokenizer, AlbertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")

        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores

    def _tie_weights(self) -> None:
        # For accelerate compatibility and to not break backward compatibility
        if self.decoder.bias.device.type == "meta":
            self.decoder.bias = self.bias
        else:
            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
            self.bias = self.decoder.bias


class AlbertSOPHead(nn.Module):
    def __init__(self, config: AlbertConfig):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits


@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.predictions.decoder = new_embeddings
        self.predictions.bias = new_embeddings.bias

    def get_input_embeddings(self) -> nn.Embedding:
        return self.albert.embeddings.word_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MaskedLMOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AlbertForMaskedLM

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")

        >>> # add mask_token
        >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> # retrieve index of [MASK]
        >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
        >>> tokenizer.decode(predicted_token_id)
        'france'
        ```

        ```python
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
        >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
        >>> outputs = model(**inputs, labels=labels)
        >>> round(outputs.loss.item(), 2)
        0.81
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
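
# Problem-type sketch (illustrative; assumes a tokenizer loaded as in the docstring
# examples above): with num_labels=1 the head acts as a regressor (MSELoss); with
# integer labels and num_labels>1 as a single-label classifier (CrossEntropyLoss);
# otherwise as a multi-label classifier (BCEWithLogitsLoss). For example:
#
#     model = AlbertForSequenceClassification.from_pretrained("albert/albert-base-v2", num_labels=3)
#     out = model(**tokenizer("great movie!", return_tensors="pt"), labels=torch.tensor([2]))
#     # out.loss is a cross-entropy over out.logits of shape (1, 3)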


@auto_docstring
class AlbertForTokenClassification(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[QuestionAnsweringModelOutput, tuple]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits: torch.Tensor = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
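
# Span-extraction sketch (illustrative; assumes a tokenizer/model pair loaded as in
# the docstring examples above):
#
#     inputs = tokenizer("Who proposed ALBERT?", "ALBERT was proposed by Lan et al.", return_tensors="pt")
#     out = model(**inputs)
#     start = out.start_logits.argmax(-1).item()
#     end = out.end_logits.argmax(-1).item()
#     answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])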


@auto_docstring
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, tuple]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (See
            *input_ids* above.)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits: torch.Tensor = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
| 1337 |
+
|
| 1338 |
+
|
| 1339 |
+
__all__ = [
|
| 1340 |
+
"load_tf_weights_in_albert",
|
| 1341 |
+
"AlbertPreTrainedModel",
|
| 1342 |
+
"AlbertModel",
|
| 1343 |
+
"AlbertForPreTraining",
|
| 1344 |
+
"AlbertForMaskedLM",
|
| 1345 |
+
"AlbertForSequenceClassification",
|
| 1346 |
+
"AlbertForTokenClassification",
|
| 1347 |
+
"AlbertForQuestionAnswering",
|
| 1348 |
+
"AlbertForMultipleChoice",
|
| 1349 |
+
]
|
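
AlbertForMultipleChoice above flattens its (batch_size, num_choices, sequence_length) inputs to (batch_size * num_choices, sequence_length), scores each flattened choice with a single linear unit on the pooled output, and reshapes the scores back to (batch_size, num_choices). A minimal usage sketch, assuming the albert/albert-base-v2 checkpoint (whose multiple-choice head is freshly initialized, so the logits are untrained) and illustrative texts:

import torch
from transformers import AlbertForMultipleChoice, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
model = AlbertForMultipleChoice.from_pretrained("albert/albert-base-v2")

prompt = "The cat sat on the"
choices = ["mat.", "carburetor."]

# Encode one (prompt, choice) pair per choice, then add the num_choices axis.
enc = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)
inputs = {k: v.unsqueeze(0) for k, v in enc.items()}  # each tensor: (1, 2, seq_len)
labels = torch.tensor([0])  # index of the correct choice

outputs = model(**inputs, labels=labels)
print(outputs.logits.shape)  # torch.Size([1, 2]) -> one score per choice
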
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_flax_albert.py
ADDED
@@ -0,0 +1,1132 @@
# coding=utf-8
# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"


@flax.struct.dataclass
class FlaxAlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`FlaxAlbertForPreTraining`].

    Args:
        prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    prediction_logits: jnp.ndarray = None
    sop_logits: jnp.ndarray = None
    hidden_states: Optional[tuple[jnp.ndarray]] = None
    attentions: Optional[tuple[jnp.ndarray]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class FlaxAlbertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxAlbertSelfAttention(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs
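
# A minimal sketch of the mask-to-bias conversion performed above: a 0/1 padding
# mask is broadcast to (batch, 1, 1, seq_len) and turned into an additive bias of
# 0 for visible tokens and a large negative value for masked ones (the toy batch
# below is illustrative only, not part of the model):
#
# >>> import jax.numpy as jnp
# >>> from jax import lax
# >>> mask = jnp.array([[1, 1, 1, 0]])
# >>> mask4d = jnp.expand_dims(mask, axis=(-3, -2))
# >>> bias = lax.select(
# ...     mask4d > 0,
# ...     jnp.full(mask4d.shape, 0.0),
# ...     jnp.full(mask4d.shape, jnp.finfo(jnp.float32).min),
# ... )
# >>> bias.shape  # (1, 1, 1, 4); added to the attention logits before softmax
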
class FlaxAlbertLayer(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
        self.ffn = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
        self.ffn_output = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        attention_outputs = self.attention(
            hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
        )
        attention_output = attention_outputs[0]
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


class FlaxAlbertLayerCollection(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.layers):
            layer_output = albert_layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class FlaxAlbertLayerCollections(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    layer_index: Optional[str] = None

    def setup(self):
        self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        outputs = self.albert_layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return outputs


class FlaxAlbertLayerGroups(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
            layer_group_output = self.layers[group_idx](
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
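
# A minimal sketch of the layer-to-group routing implemented in
# FlaxAlbertLayerGroups above: each of the num_hidden_layers forward iterations
# reuses the parameters of one of num_hidden_groups groups. Config values below
# are illustrative (albert-base uses 12 layers and a single group):
#
# >>> num_hidden_layers, num_hidden_groups = 12, 1
# >>> [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(num_hidden_layers)]
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# >>> [int(i / (12 / 3)) for i in range(12)]  # hypothetical 3-group config
# [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
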
class FlaxAlbertEncoder(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class FlaxAlbertOnlyMLMHead(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        hidden_states += self.bias
        return hidden_states


class FlaxAlbertSOPHead(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dropout = nn.Dropout(self.config.classifier_dropout_prob)
        self.classifier = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output, deterministic=True):
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)
        return logits


class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: AlbertConfig,
        input_shape: tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: Optional[dict] = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
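
# A minimal usage sketch for the `__call__` defined above: inference is
# deterministic by default, while `train=True` enables dropout and therefore
# requires an explicit PRNG key. `FlaxAlbertModel` is the concrete subclass
# defined below; the checkpoint name matches _CHECKPOINT_FOR_DOC in this file:
#
# >>> import jax
# >>> from transformers import AutoTokenizer, FlaxAlbertModel
# >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
# >>> model = FlaxAlbertModel.from_pretrained("albert/albert-base-v2")
# >>> inputs = tokenizer("ALBERT shares parameters across layers.", return_tensors="np")
# >>> out = model(**inputs)  # no dropout applied
# >>> out = model(**inputs, train=True, dropout_rng=jax.random.PRNGKey(0))
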
class FlaxAlbertModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)
        if self.add_pooling_layer:
            self.pooler = nn.Dense(
                self.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
                name="pooler",
            )
            self.pooler_activation = nn.tanh
        else:
            self.pooler = None
            self.pooler_activation = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[np.ndarray] = None,
        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        if self.add_pooling_layer:
            pooled = self.pooler(hidden_states[:, 0])
            pooled = self.pooler_activation(pooled)
        else:
            pooled = None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertModule


append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
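
# A short sketch of what FlaxAlbertModel above returns, assuming the default
# return_dict=True (hidden_size is 768 for the albert-base configuration):
#
# >>> from transformers import AutoTokenizer, FlaxAlbertModel
# >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
# >>> model = FlaxAlbertModel.from_pretrained("albert/albert-base-v2")
# >>> out = model(**tokenizer("Hello world", return_tensors="np"))
# >>> out.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
# >>> out.pooler_output.shape  # (batch_size, hidden_size), tanh-pooled first token
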
class FlaxAlbertForPreTrainingModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
        self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        hidden_states = outputs[0]
        pooled_output = outputs[1]

        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic)

        if not return_dict:
            return (prediction_scores, sop_scores) + outputs[2:]

        return FlaxAlbertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForPreTrainingModule


FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
    >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.sop_logits
    ```
"""

overwrite_call_docstring(
    FlaxAlbertForPreTraining,
    ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxAlbertForMaskedLMModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMaskedLMModule


append_call_sample_docstring(
    FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
)
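
# A minimal mask-filling sketch for FlaxAlbertForMaskedLM above; the decoded
# token depends entirely on the checkpoint, so no output is asserted here:
#
# >>> from transformers import AutoTokenizer, FlaxAlbertForMaskedLM
# >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
# >>> model = FlaxAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
# >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
# >>> logits = model(**inputs).logits
# >>> mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).argmax()
# >>> predicted_id = logits[0, mask_pos].argmax(-1)
# >>> tokenizer.decode([int(predicted_id)])
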
class FlaxAlbertForSequenceClassificationModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxAlbertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxAlbertForMultipleChoiceModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMultipleChoiceModule


overwrite_call_docstring(
    FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxAlbertForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxAlbertForTokenClassificationModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForTokenClassificationModule


append_call_sample_docstring(
    FlaxAlbertForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)


class FlaxAlbertForQuestionAnsweringModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxAlbertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
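
# A minimal span-extraction sketch for FlaxAlbertForQuestionAnswering above:
# `qa_outputs` yields two logits per token, split into start/end scores whose
# argmax positions bound the answer span. The question/context pair is
# illustrative and the decoded span depends on the checkpoint:
#
# >>> from transformers import AutoTokenizer, FlaxAlbertForQuestionAnswering
# >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
# >>> model = FlaxAlbertForQuestionAnswering.from_pretrained("albert/albert-base-v2")
# >>> inputs = tokenizer("Who proposed ALBERT?", "ALBERT was proposed by Lan et al.", return_tensors="np")
# >>> out = model(**inputs)
# >>> start = int(out.start_logits.argmax(-1)[0])
# >>> end = int(out.end_logits.argmax(-1)[0])
# >>> tokenizer.decode(inputs["input_ids"][0, start : end + 1])
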
__all__ = [
    "FlaxAlbertPreTrainedModel",
    "FlaxAlbertModel",
    "FlaxAlbertForPreTraining",
    "FlaxAlbertForMaskedLM",
    "FlaxAlbertForSequenceClassification",
    "FlaxAlbertForMultipleChoice",
    "FlaxAlbertForTokenClassification",
    "FlaxAlbertForQuestionAnswering",
]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/modeling_tf_albert.py
ADDED
@@ -0,0 +1,1572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 ALBERT model."""

from __future__ import annotations

import math
from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"


class TFAlbertPreTrainingLoss:
    """
    Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP +
    MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
            masked_lm_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
                mask=masked_lm_active_loss,
            )
            masked_lm_labels = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
            )
            sentence_order_active_loss = tf.not_equal(
                tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
            )
            sentence_order_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
            )
            sentence_order_label = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
            )
            masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
            sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
            masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
            masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)

            return masked_lm_loss + sentence_order_loss

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)

        sop_logits = tf.reshape(logits[1], (-1, 2))
        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
        sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)

        masked_sop_loss = unmasked_sop_loss * sop_loss_mask
        reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)

        return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))


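# Illustrative sketch (not from the upstream file): the "-100 means ignore" masking
# used by `hf_compute_loss` above, shown on toy data; shapes and values are assumptions.
def _sketch_masked_lm_loss():
    loss_fn = keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=keras.losses.Reduction.NONE
    )
    labels = tf.constant([[5, -100, 2]])  # -100 marks positions excluded from the loss
    logits = tf.random.normal((1, 3, 10))  # (batch, seq_len, vocab_size) stand-in
    per_token = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)  # clip -100 -> 0 first
    mask = tf.cast(labels != -100, per_token.dtype)
    return tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)  # mean over real tokens

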
class TFAlbertEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings


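# Illustrative sketch (not from the upstream file): the lookup-and-sum that
# `TFAlbertEmbeddings.call` performs, with toy dimensions chosen arbitrarily.
def _sketch_embedding_sum():
    vocab_size, type_vocab, max_pos, emb = 100, 2, 32, 8
    word_emb = tf.random.normal((vocab_size, emb))
    type_emb = tf.random.normal((type_vocab, emb))
    pos_emb = tf.random.normal((max_pos, emb))
    input_ids = tf.constant([[7, 42, 3]])  # (batch=1, seq_len=3)
    token_type_ids = tf.zeros_like(input_ids)
    position_ids = tf.expand_dims(tf.range(3), axis=0)
    # Three gathers, one elementwise sum -> (1, 3, 8); LayerNorm and dropout follow in the real layer.
    return (
        tf.gather(word_emb, input_ids)
        + tf.gather(pos_emb, position_ids)
        + tf.gather(type_emb, token_type_ids)
    )

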
class TFAlbertAttention(keras.layers.Layer):
    """Contains the complete attention sublayer, including both dropouts and layer norm."""

    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.output_attentions = config.output_attentions

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
        self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> tuple[tf.Tensor]:
        batch_size = shape_list(input_tensor)[0]
        mixed_query_layer = self.query(inputs=input_tensor)
        mixed_key_layer = self.key(inputs=input_tensor)
        mixed_value_layer = self.value(inputs=input_tensor)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFAlbertModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size))
        self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        hidden_states = self_outputs[0]
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.output_dropout(inputs=hidden_states, training=training)
        attention_output = self.LayerNorm(inputs=hidden_states + input_tensor)

        # add attentions if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])


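# Illustrative sketch (not from the upstream file): the head-split / scaled dot-product /
# merge sequence from `TFAlbertAttention.call`, with assumed toy shapes.
def _sketch_scaled_dot_product():
    batch, seq_len, num_heads, head_size = 2, 5, 4, 8
    x = tf.random.normal((batch, seq_len, num_heads * head_size))

    def split_heads(t):
        t = tf.reshape(t, (batch, -1, num_heads, head_size))
        return tf.transpose(t, perm=[0, 2, 1, 3])  # (batch, heads, seq, head_size)

    q, k, v = split_heads(x), split_heads(x), split_heads(x)
    scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(head_size)  # (batch, heads, seq, seq)
    probs = stable_softmax(logits=scores, axis=-1)
    context = tf.matmul(probs, v)  # (batch, heads, seq, head_size)
    context = tf.transpose(context, perm=[0, 2, 1, 3])
    return tf.reshape(context, (batch, seq_len, num_heads * head_size))  # back to input shape

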
class TFAlbertLayer(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFAlbertAttention(config, name="attention")
        self.ffn = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
        )

        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        self.ffn_output = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output"
        )
        self.full_layer_layer_norm = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> tuple[tf.Tensor]:
        attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        ffn_output = self.ffn(inputs=attention_outputs[0])
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(inputs=ffn_output)
        ffn_output = self.dropout(inputs=ffn_output, training=training)
        hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0])

        # add attentions if we output them
        outputs = (hidden_states,) + attention_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "ffn", None) is not None:
            with tf.name_scope(self.ffn.name):
                self.ffn.build([None, None, self.config.hidden_size])
        if getattr(self, "ffn_output", None) is not None:
            with tf.name_scope(self.ffn_output.name):
                self.ffn_output.build([None, None, self.config.intermediate_size])
        if getattr(self, "full_layer_layer_norm", None) is not None:
            with tf.name_scope(self.full_layer_layer_norm.name):
                self.full_layer_layer_norm.build([None, None, self.config.hidden_size])


class TFAlbertLayerGroup(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.albert_layers = [
            TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num)
        ]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        training: bool = False,
    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
        layer_hidden_states = () if output_hidden_states else None
        layer_attentions = () if output_attentions else None

        for layer_index, albert_layer in enumerate(self.albert_layers):
            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

            layer_output = albert_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[layer_index],
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

        # Add last layer
        if output_hidden_states:
            layer_hidden_states = layer_hidden_states + (hidden_states,)

        return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert_layers", None) is not None:
            for layer in self.albert_layers:
                with tf.name_scope(layer.name):
                    layer.build(None)


class TFAlbertTransformer(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        # Number of layers in a hidden group
        self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups)
        self.embedding_hidden_mapping_in = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="embedding_hidden_mapping_in",
        )
        self.albert_layer_groups = [
            TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
        ]
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> TFBaseModelOutput | tuple[tf.Tensor]:
        hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups))
            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group],
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                training=training,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embedding_hidden_mapping_in", None) is not None:
            with tf.name_scope(self.embedding_hidden_mapping_in.name):
                self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
        if getattr(self, "albert_layer_groups", None) is not None:
            for layer in self.albert_layer_groups:
                with tf.name_scope(layer.name):
                    layer.build(None)


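# Illustrative sketch (not from the upstream file): how the loop in
# `TFAlbertTransformer.call` maps each of the `num_hidden_layers` passes onto a shared
# layer group -- ALBERT's cross-layer parameter sharing (ALBERT-base defaults assumed).
def _sketch_layer_group_mapping(num_hidden_layers=12, num_hidden_groups=1):
    layers_per_group = num_hidden_layers // num_hidden_groups
    mapping = []
    for i in range(num_hidden_layers):
        group_idx = int(i / (num_hidden_layers / num_hidden_groups))
        head_mask_slice = (group_idx * layers_per_group, (group_idx + 1) * layers_per_group)
        mapping.append((i, group_idx, head_mask_slice))
    # With the defaults every layer reuses group 0, i.e. one set of weights run 12 times.
    return mapping

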
class TFAlbertPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"


class TFAlbertMLMHead(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.dense = keras.layers.Dense(
            config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = input_embeddings

    def build(self, input_shape=None):
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
        self.decoder_bias = self.add_weight(
            shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
        )

        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.decoder

    def set_output_embeddings(self, value: tf.Variable):
        self.decoder.weight = value
        self.decoder.vocab_size = shape_list(value)[0]

    def get_bias(self) -> dict[str, tf.Variable]:
        return {"bias": self.bias, "decoder_bias": self.decoder_bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.decoder_bias = value["decoder_bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(inputs=hidden_states)
        seq_length = shape_list(tensor=hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias)

        return hidden_states


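# Illustrative sketch (not from the upstream file): the weight-tied decode in
# `TFAlbertMLMHead.call` -- hidden states projected onto the vocabulary via the
# *transposed* input embedding matrix plus an output-only bias (toy shapes assumed).
def _sketch_tied_decode():
    vocab_size, emb, batch, seq_len = 100, 16, 2, 4
    embedding_matrix = tf.random.normal((vocab_size, emb))  # shared with the input embeddings
    decoder_bias = tf.zeros((vocab_size,))
    hidden = tf.random.normal((batch, seq_len, emb))
    flat = tf.reshape(hidden, (-1, emb))
    logits = tf.matmul(flat, embedding_matrix, transpose_b=True)  # (batch * seq, vocab)
    logits = tf.reshape(logits, (-1, seq_len, vocab_size))
    return tf.nn.bias_add(value=logits, bias=decoder_bias)

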
@keras_serializable
class TFAlbertMainLayer(keras.layers.Layer):
    config_class = AlbertConfig

    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
        self.encoder = TFAlbertTransformer(config, name="encoder")
        self.pooler = (
            keras.layers.Dense(
                units=config.hidden_size,
                kernel_initializer=get_initializer(config.initializer_range),
                activation="tanh",
                name="pooler",
            )
            if add_pooling_layer
            else None
        )

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
        class PreTrainedModel.
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            training=training,
        )

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build([None, None, self.config.hidden_size])


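# Illustrative sketch (not from the upstream file): the additive attention-mask trick
# from `TFAlbertMainLayer.call`, shown on a toy 2D mask.
def _sketch_extended_attention_mask():
    attention_mask = tf.constant([[1, 1, 0]])  # 1 = real token, 0 = padding
    ext = tf.reshape(attention_mask, (1, 1, 1, 3))  # broadcastable over heads and queries
    ext = tf.cast(ext, tf.float32)
    # 0.0 where we attend, -10000.0 where masked; added to raw scores before the softmax.
    return (1.0 - ext) * -10000.0

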
@dataclass
class TFAlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFAlbertForPreTraining`].

    Args:
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: tf.Tensor | None = None
    prediction_logits: tf.Tensor | None = None
    sop_logits: tf.Tensor | None = None
    hidden_states: tuple[tf.Tensor] | None = None
    attentions: tuple[tf.Tensor] | None = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode; in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class TFAlbertModel(TFAlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertMainLayer(config, name="albert")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool | None = False,
    ) -> TFBaseModelOutputWithPooling | tuple[tf.Tensor]:
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)


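# Illustrative usage sketch (not from the upstream file; assumes network access to
# the `albert/albert-base-v2` checkpoint):
#
#     from transformers import AutoTokenizer, TFAlbertModel
#
#     tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
#     model = TFAlbertModel.from_pretrained("albert/albert-base-v2")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape  # (1, sequence_length, hidden_size)

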
@add_start_docstrings(
    """
    Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
    prediction` (classification) head.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.albert = TFAlbertMainLayer(config, name="albert")
        self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
        self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        sentence_order_label: np.ndarray | tf.Tensor | None = None,
        training: bool | None = False,
    ) -> TFAlbertForPreTrainingOutput | tuple[tf.Tensor]:
        r"""
        Return:

        Example:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFAlbertForPreTraining

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

        >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.predictions(hidden_states=sequence_output)
        sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training)
        total_loss = None

        if labels is not None and sentence_order_label is not None:
            d_labels = {"labels": labels}
            d_labels["sentence_order_label"] = sentence_order_label
            total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TFAlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)
        if getattr(self, "sop_classifier", None) is not None:
            with tf.name_scope(self.sop_classifier.name):
                self.sop_classifier.build(None)


class TFAlbertSOPHead(keras.layers.Layer):
|
| 1021 |
+
def __init__(self, config: AlbertConfig, **kwargs):
|
| 1022 |
+
super().__init__(**kwargs)
|
| 1023 |
+
|
| 1024 |
+
self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
|
| 1025 |
+
self.classifier = keras.layers.Dense(
|
| 1026 |
+
units=config.num_labels,
|
| 1027 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1028 |
+
name="classifier",
|
| 1029 |
+
)
|
| 1030 |
+
self.config = config
|
| 1031 |
+
|
| 1032 |
+
def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
|
| 1033 |
+
dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
|
| 1034 |
+
logits = self.classifier(inputs=dropout_pooled_output)
|
| 1035 |
+
|
| 1036 |
+
return logits
|
| 1037 |
+
|
| 1038 |
+
def build(self, input_shape=None):
|
| 1039 |
+
if self.built:
|
| 1040 |
+
return
|
| 1041 |
+
self.built = True
|
| 1042 |
+
if getattr(self, "classifier", None) is not None:
|
| 1043 |
+
with tf.name_scope(self.classifier.name):
|
| 1044 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
|
| 1048 |
+
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
| 1049 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1050 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]
|
| 1051 |
+
|
| 1052 |
+
def __init__(self, config: AlbertConfig, *inputs, **kwargs):
|
| 1053 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1054 |
+
|
| 1055 |
+
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
| 1056 |
+
self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
|
| 1057 |
+
|
| 1058 |
+
def get_lm_head(self) -> keras.layers.Layer:
|
| 1059 |
+
return self.predictions
|
| 1060 |
+
|
| 1061 |
+
@unpack_inputs
|
| 1062 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1063 |
+
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
| 1064 |
+
def call(
|
| 1065 |
+
self,
|
| 1066 |
+
input_ids: TFModelInputType | None = None,
|
| 1067 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1068 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1069 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1070 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1071 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1072 |
+
output_attentions: bool | None = None,
|
| 1073 |
+
output_hidden_states: bool | None = None,
|
| 1074 |
+
return_dict: bool | None = None,
|
| 1075 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1076 |
+
training: bool | None = False,
|
| 1077 |
+
) -> TFMaskedLMOutput | tuple[tf.Tensor]:
|
| 1078 |
+
r"""
|
| 1079 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1080 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 1081 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 1082 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 1083 |
+
|
| 1084 |
+
Returns:
|
| 1085 |
+
|
| 1086 |
+
Example:
|
| 1087 |
+
|
| 1088 |
+
```python
|
| 1089 |
+
>>> import tensorflow as tf
|
| 1090 |
+
>>> from transformers import AutoTokenizer, TFAlbertForMaskedLM
|
| 1091 |
+
|
| 1092 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
|
| 1093 |
+
>>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
|
| 1094 |
+
|
| 1095 |
+
>>> # add mask_token
|
| 1096 |
+
>>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf")
|
| 1097 |
+
>>> logits = model(**inputs).logits
|
| 1098 |
+
|
| 1099 |
+
>>> # retrieve index of [MASK]
|
| 1100 |
+
>>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
|
| 1101 |
+
>>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
|
| 1102 |
+
>>> tokenizer.decode(predicted_token_id)
|
| 1103 |
+
'france'
|
| 1104 |
+
```
|
| 1105 |
+
|
| 1106 |
+
```python
|
| 1107 |
+
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
|
| 1108 |
+
>>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
|
| 1109 |
+
>>> outputs = model(**inputs, labels=labels)
|
| 1110 |
+
>>> round(float(outputs.loss), 2)
|
| 1111 |
+
0.81
|
| 1112 |
+
```
|
| 1113 |
+
"""
|
| 1114 |
+
outputs = self.albert(
|
| 1115 |
+
input_ids=input_ids,
|
| 1116 |
+
attention_mask=attention_mask,
|
| 1117 |
+
token_type_ids=token_type_ids,
|
| 1118 |
+
position_ids=position_ids,
|
| 1119 |
+
head_mask=head_mask,
|
| 1120 |
+
inputs_embeds=inputs_embeds,
|
| 1121 |
+
output_attentions=output_attentions,
|
| 1122 |
+
output_hidden_states=output_hidden_states,
|
| 1123 |
+
return_dict=return_dict,
|
| 1124 |
+
training=training,
|
| 1125 |
+
)
|
| 1126 |
+
sequence_output = outputs[0]
|
| 1127 |
+
prediction_scores = self.predictions(hidden_states=sequence_output, training=training)
|
| 1128 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
|
| 1129 |
+
|
| 1130 |
+
if not return_dict:
|
| 1131 |
+
output = (prediction_scores,) + outputs[2:]
|
| 1132 |
+
|
| 1133 |
+
return ((loss,) + output) if loss is not None else output
|
| 1134 |
+
|
| 1135 |
+
return TFMaskedLMOutput(
|
| 1136 |
+
loss=loss,
|
| 1137 |
+
logits=prediction_scores,
|
| 1138 |
+
hidden_states=outputs.hidden_states,
|
| 1139 |
+
attentions=outputs.attentions,
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
def build(self, input_shape=None):
|
| 1143 |
+
if self.built:
|
| 1144 |
+
return
|
| 1145 |
+
self.built = True
|
| 1146 |
+
if getattr(self, "albert", None) is not None:
|
| 1147 |
+
with tf.name_scope(self.albert.name):
|
| 1148 |
+
self.albert.build(None)
|
| 1149 |
+
if getattr(self, "predictions", None) is not None:
|
| 1150 |
+
with tf.name_scope(self.predictions.name):
|
| 1151 |
+
self.predictions.build(None)
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
@add_start_docstrings(
|
| 1155 |
+
"""
|
| 1156 |
+
Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
| 1157 |
+
output) e.g. for GLUE tasks.
|
| 1158 |
+
""",
|
| 1159 |
+
ALBERT_START_DOCSTRING,
|
| 1160 |
+
)
|
| 1161 |
+
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
|
| 1162 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1163 |
+
_keys_to_ignore_on_load_unexpected = [r"predictions"]
|
| 1164 |
+
_keys_to_ignore_on_load_missing = [r"dropout"]
|
| 1165 |
+
|
| 1166 |
+
def __init__(self, config: AlbertConfig, *inputs, **kwargs):
|
| 1167 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1168 |
+
|
| 1169 |
+
self.num_labels = config.num_labels
|
| 1170 |
+
|
| 1171 |
+
self.albert = TFAlbertMainLayer(config, name="albert")
|
| 1172 |
+
self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
|
| 1173 |
+
self.classifier = keras.layers.Dense(
|
| 1174 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1175 |
+
)
|
| 1176 |
+
self.config = config
|
| 1177 |
+
|
| 1178 |
+
@unpack_inputs
|
| 1179 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1180 |
+
@add_code_sample_docstrings(
|
| 1181 |
+
checkpoint="vumichien/albert-base-v2-imdb",
|
| 1182 |
+
output_type=TFSequenceClassifierOutput,
|
| 1183 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1184 |
+
expected_output="'LABEL_1'",
|
| 1185 |
+
expected_loss=0.12,
|
| 1186 |
+
)
|
| 1187 |
+
def call(
|
| 1188 |
+
self,
|
| 1189 |
+
input_ids: TFModelInputType | None = None,
|
| 1190 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1191 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1192 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1193 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1194 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1195 |
+
output_attentions: bool | None = None,
|
| 1196 |
+
output_hidden_states: bool | None = None,
|
| 1197 |
+
return_dict: bool | None = None,
|
| 1198 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1199 |
+
training: bool | None = False,
|
| 1200 |
+
) -> TFSequenceClassifierOutput | tuple[tf.Tensor]:
|
| 1201 |
+
r"""
|
| 1202 |
+
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 1203 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1204 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1205 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1206 |
+
"""
|
| 1207 |
+
outputs = self.albert(
|
| 1208 |
+
input_ids=input_ids,
|
| 1209 |
+
attention_mask=attention_mask,
|
| 1210 |
+
token_type_ids=token_type_ids,
|
| 1211 |
+
position_ids=position_ids,
|
| 1212 |
+
head_mask=head_mask,
|
| 1213 |
+
inputs_embeds=inputs_embeds,
|
| 1214 |
+
output_attentions=output_attentions,
|
| 1215 |
+
output_hidden_states=output_hidden_states,
|
| 1216 |
+
return_dict=return_dict,
|
| 1217 |
+
training=training,
|
| 1218 |
+
)
|
| 1219 |
+
pooled_output = outputs[1]
|
| 1220 |
+
pooled_output = self.dropout(inputs=pooled_output, training=training)
|
| 1221 |
+
logits = self.classifier(inputs=pooled_output)
|
| 1222 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1223 |
+
|
| 1224 |
+
if not return_dict:
|
| 1225 |
+
output = (logits,) + outputs[2:]
|
| 1226 |
+
|
| 1227 |
+
return ((loss,) + output) if loss is not None else output
|
| 1228 |
+
|
| 1229 |
+
return TFSequenceClassifierOutput(
|
| 1230 |
+
loss=loss,
|
| 1231 |
+
logits=logits,
|
| 1232 |
+
hidden_states=outputs.hidden_states,
|
| 1233 |
+
attentions=outputs.attentions,
|
| 1234 |
+
)
|
| 1235 |
+
|
| 1236 |
+
def build(self, input_shape=None):
|
| 1237 |
+
if self.built:
|
| 1238 |
+
return
|
| 1239 |
+
self.built = True
|
| 1240 |
+
if getattr(self, "albert", None) is not None:
|
| 1241 |
+
with tf.name_scope(self.albert.name):
|
| 1242 |
+
self.albert.build(None)
|
| 1243 |
+
if getattr(self, "classifier", None) is not None:
|
| 1244 |
+
with tf.name_scope(self.classifier.name):
|
| 1245 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1246 |
+
|
| 1247 |
+
|
| 1248 |
+
@add_start_docstrings(
|
| 1249 |
+
"""
|
| 1250 |
+
Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
| 1251 |
+
Named-Entity-Recognition (NER) tasks.
|
| 1252 |
+
""",
|
| 1253 |
+
ALBERT_START_DOCSTRING,
|
| 1254 |
+
)
|
| 1255 |
+
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
|
| 1256 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1257 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
|
| 1258 |
+
_keys_to_ignore_on_load_missing = [r"dropout"]
|
| 1259 |
+
|
| 1260 |
+
def __init__(self, config: AlbertConfig, *inputs, **kwargs):
|
| 1261 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1262 |
+
|
| 1263 |
+
self.num_labels = config.num_labels
|
| 1264 |
+
|
| 1265 |
+
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
| 1266 |
+
classifier_dropout_prob = (
|
| 1267 |
+
config.classifier_dropout_prob
|
| 1268 |
+
if config.classifier_dropout_prob is not None
|
| 1269 |
+
else config.hidden_dropout_prob
|
| 1270 |
+
)
|
| 1271 |
+
self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob)
|
| 1272 |
+
self.classifier = keras.layers.Dense(
|
| 1273 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1274 |
+
)
|
| 1275 |
+
self.config = config
|
| 1276 |
+
|
| 1277 |
+
@unpack_inputs
|
| 1278 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1279 |
+
@add_code_sample_docstrings(
|
| 1280 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1281 |
+
output_type=TFTokenClassifierOutput,
|
| 1282 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1283 |
+
)
|
| 1284 |
+
def call(
|
| 1285 |
+
self,
|
| 1286 |
+
input_ids: TFModelInputType | None = None,
|
| 1287 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1288 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1289 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1290 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1291 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1292 |
+
output_attentions: bool | None = None,
|
| 1293 |
+
output_hidden_states: bool | None = None,
|
| 1294 |
+
return_dict: bool | None = None,
|
| 1295 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1296 |
+
training: bool | None = False,
|
| 1297 |
+
) -> TFTokenClassifierOutput | tuple[tf.Tensor]:
|
| 1298 |
+
r"""
|
| 1299 |
+
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1300 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
| 1301 |
+
"""
|
| 1302 |
+
outputs = self.albert(
|
| 1303 |
+
input_ids=input_ids,
|
| 1304 |
+
attention_mask=attention_mask,
|
| 1305 |
+
token_type_ids=token_type_ids,
|
| 1306 |
+
position_ids=position_ids,
|
| 1307 |
+
head_mask=head_mask,
|
| 1308 |
+
inputs_embeds=inputs_embeds,
|
| 1309 |
+
output_attentions=output_attentions,
|
| 1310 |
+
output_hidden_states=output_hidden_states,
|
| 1311 |
+
return_dict=return_dict,
|
| 1312 |
+
training=training,
|
| 1313 |
+
)
|
| 1314 |
+
sequence_output = outputs[0]
|
| 1315 |
+
sequence_output = self.dropout(inputs=sequence_output, training=training)
|
| 1316 |
+
logits = self.classifier(inputs=sequence_output)
|
| 1317 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 1318 |
+
|
| 1319 |
+
if not return_dict:
|
| 1320 |
+
output = (logits,) + outputs[2:]
|
| 1321 |
+
|
| 1322 |
+
return ((loss,) + output) if loss is not None else output
|
| 1323 |
+
|
| 1324 |
+
return TFTokenClassifierOutput(
|
| 1325 |
+
loss=loss,
|
| 1326 |
+
logits=logits,
|
| 1327 |
+
hidden_states=outputs.hidden_states,
|
| 1328 |
+
attentions=outputs.attentions,
|
| 1329 |
+
)
|
| 1330 |
+
|
| 1331 |
+
def build(self, input_shape=None):
|
| 1332 |
+
if self.built:
|
| 1333 |
+
return
|
| 1334 |
+
self.built = True
|
| 1335 |
+
if getattr(self, "albert", None) is not None:
|
| 1336 |
+
with tf.name_scope(self.albert.name):
|
| 1337 |
+
self.albert.build(None)
|
| 1338 |
+
if getattr(self, "classifier", None) is not None:
|
| 1339 |
+
with tf.name_scope(self.classifier.name):
|
| 1340 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1341 |
+
|
| 1342 |
+
|
| 1343 |
+
@add_start_docstrings(
|
| 1344 |
+
"""
|
| 1345 |
+
Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1346 |
+
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
| 1347 |
+
""",
|
| 1348 |
+
ALBERT_START_DOCSTRING,
|
| 1349 |
+
)
|
| 1350 |
+
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
|
| 1351 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1352 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
|
| 1353 |
+
|
| 1354 |
+
def __init__(self, config: AlbertConfig, *inputs, **kwargs):
|
| 1355 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1356 |
+
|
| 1357 |
+
self.num_labels = config.num_labels
|
| 1358 |
+
|
| 1359 |
+
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
|
| 1360 |
+
self.qa_outputs = keras.layers.Dense(
|
| 1361 |
+
units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
|
| 1362 |
+
)
|
| 1363 |
+
self.config = config
|
| 1364 |
+
|
| 1365 |
+
@unpack_inputs
|
| 1366 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 1367 |
+
@add_code_sample_docstrings(
|
| 1368 |
+
checkpoint="vumichien/albert-base-v2-squad2",
|
| 1369 |
+
output_type=TFQuestionAnsweringModelOutput,
|
| 1370 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1371 |
+
qa_target_start_index=12,
|
| 1372 |
+
qa_target_end_index=13,
|
| 1373 |
+
expected_output="'a nice puppet'",
|
| 1374 |
+
expected_loss=7.36,
|
| 1375 |
+
)
|
| 1376 |
+
def call(
|
| 1377 |
+
self,
|
| 1378 |
+
input_ids: TFModelInputType | None = None,
|
| 1379 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1380 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1381 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1382 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1383 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1384 |
+
output_attentions: bool | None = None,
|
| 1385 |
+
output_hidden_states: bool | None = None,
|
| 1386 |
+
return_dict: bool | None = None,
|
| 1387 |
+
start_positions: np.ndarray | tf.Tensor | None = None,
|
| 1388 |
+
end_positions: np.ndarray | tf.Tensor | None = None,
|
| 1389 |
+
training: bool | None = False,
|
| 1390 |
+
) -> TFQuestionAnsweringModelOutput | tuple[tf.Tensor]:
|
| 1391 |
+
r"""
|
| 1392 |
+
start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 1393 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1394 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1395 |
+
are not taken into account for computing the loss.
|
| 1396 |
+
end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 1397 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1398 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1399 |
+
are not taken into account for computing the loss.
|
| 1400 |
+
"""
|
| 1401 |
+
outputs = self.albert(
|
| 1402 |
+
input_ids=input_ids,
|
| 1403 |
+
attention_mask=attention_mask,
|
| 1404 |
+
token_type_ids=token_type_ids,
|
| 1405 |
+
position_ids=position_ids,
|
| 1406 |
+
head_mask=head_mask,
|
| 1407 |
+
inputs_embeds=inputs_embeds,
|
| 1408 |
+
output_attentions=output_attentions,
|
| 1409 |
+
output_hidden_states=output_hidden_states,
|
| 1410 |
+
return_dict=return_dict,
|
| 1411 |
+
training=training,
|
| 1412 |
+
)
|
| 1413 |
+
sequence_output = outputs[0]
|
| 1414 |
+
logits = self.qa_outputs(inputs=sequence_output)
|
| 1415 |
+
start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
|
| 1416 |
+
start_logits = tf.squeeze(input=start_logits, axis=-1)
|
| 1417 |
+
end_logits = tf.squeeze(input=end_logits, axis=-1)
|
| 1418 |
+
loss = None
|
| 1419 |
+
|
| 1420 |
+
if start_positions is not None and end_positions is not None:
|
| 1421 |
+
labels = {"start_position": start_positions}
|
| 1422 |
+
labels["end_position"] = end_positions
|
| 1423 |
+
loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
|
| 1424 |
+
|
| 1425 |
+
if not return_dict:
|
| 1426 |
+
output = (start_logits, end_logits) + outputs[2:]
|
| 1427 |
+
|
| 1428 |
+
return ((loss,) + output) if loss is not None else output
|
| 1429 |
+
|
| 1430 |
+
return TFQuestionAnsweringModelOutput(
|
| 1431 |
+
loss=loss,
|
| 1432 |
+
start_logits=start_logits,
|
| 1433 |
+
end_logits=end_logits,
|
| 1434 |
+
hidden_states=outputs.hidden_states,
|
| 1435 |
+
attentions=outputs.attentions,
|
| 1436 |
+
)
|
| 1437 |
+
|
| 1438 |
+
def build(self, input_shape=None):
|
| 1439 |
+
if self.built:
|
| 1440 |
+
return
|
| 1441 |
+
self.built = True
|
| 1442 |
+
if getattr(self, "albert", None) is not None:
|
| 1443 |
+
with tf.name_scope(self.albert.name):
|
| 1444 |
+
self.albert.build(None)
|
| 1445 |
+
if getattr(self, "qa_outputs", None) is not None:
|
| 1446 |
+
with tf.name_scope(self.qa_outputs.name):
|
| 1447 |
+
self.qa_outputs.build([None, None, self.config.hidden_size])
|
| 1448 |
+
|
| 1449 |
+
|
| 1450 |
+
@add_start_docstrings(
|
| 1451 |
+
"""
|
| 1452 |
+
Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
| 1453 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
| 1454 |
+
""",
|
| 1455 |
+
ALBERT_START_DOCSTRING,
|
| 1456 |
+
)
|
| 1457 |
+
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
| 1458 |
+
# names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
|
| 1459 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
|
| 1460 |
+
_keys_to_ignore_on_load_missing = [r"dropout"]
|
| 1461 |
+
|
| 1462 |
+
def __init__(self, config: AlbertConfig, *inputs, **kwargs):
|
| 1463 |
+
super().__init__(config, *inputs, **kwargs)
|
| 1464 |
+
|
| 1465 |
+
self.albert = TFAlbertMainLayer(config, name="albert")
|
| 1466 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 1467 |
+
self.classifier = keras.layers.Dense(
|
| 1468 |
+
units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
| 1469 |
+
)
|
| 1470 |
+
self.config = config
|
| 1471 |
+
|
| 1472 |
+
@unpack_inputs
|
| 1473 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
| 1474 |
+
@add_code_sample_docstrings(
|
| 1475 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1476 |
+
output_type=TFMultipleChoiceModelOutput,
|
| 1477 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1478 |
+
)
|
| 1479 |
+
def call(
|
| 1480 |
+
self,
|
| 1481 |
+
input_ids: TFModelInputType | None = None,
|
| 1482 |
+
attention_mask: np.ndarray | tf.Tensor | None = None,
|
| 1483 |
+
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
| 1484 |
+
position_ids: np.ndarray | tf.Tensor | None = None,
|
| 1485 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1486 |
+
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
| 1487 |
+
output_attentions: bool | None = None,
|
| 1488 |
+
output_hidden_states: bool | None = None,
|
| 1489 |
+
return_dict: bool | None = None,
|
| 1490 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 1491 |
+
training: bool | None = False,
|
| 1492 |
+
) -> TFMultipleChoiceModelOutput | tuple[tf.Tensor]:
|
| 1493 |
+
r"""
|
| 1494 |
+
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
|
| 1495 |
+
Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
|
| 1496 |
+
where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
|
| 1497 |
+
"""
|
| 1498 |
+
|
| 1499 |
+
if input_ids is not None:
|
| 1500 |
+
num_choices = shape_list(input_ids)[1]
|
| 1501 |
+
seq_length = shape_list(input_ids)[2]
|
| 1502 |
+
else:
|
| 1503 |
+
num_choices = shape_list(inputs_embeds)[1]
|
| 1504 |
+
seq_length = shape_list(inputs_embeds)[2]
|
| 1505 |
+
|
| 1506 |
+
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
| 1507 |
+
flat_attention_mask = (
|
| 1508 |
+
tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
|
| 1509 |
+
)
|
| 1510 |
+
flat_token_type_ids = (
|
| 1511 |
+
tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
|
| 1512 |
+
)
|
| 1513 |
+
flat_position_ids = (
|
| 1514 |
+
tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
|
| 1515 |
+
)
|
| 1516 |
+
flat_inputs_embeds = (
|
| 1517 |
+
tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
|
| 1518 |
+
if inputs_embeds is not None
|
| 1519 |
+
else None
|
| 1520 |
+
)
|
| 1521 |
+
outputs = self.albert(
|
| 1522 |
+
input_ids=flat_input_ids,
|
| 1523 |
+
attention_mask=flat_attention_mask,
|
| 1524 |
+
token_type_ids=flat_token_type_ids,
|
| 1525 |
+
position_ids=flat_position_ids,
|
| 1526 |
+
head_mask=head_mask,
|
| 1527 |
+
inputs_embeds=flat_inputs_embeds,
|
| 1528 |
+
output_attentions=output_attentions,
|
| 1529 |
+
output_hidden_states=output_hidden_states,
|
| 1530 |
+
return_dict=return_dict,
|
| 1531 |
+
training=training,
|
| 1532 |
+
)
|
| 1533 |
+
pooled_output = outputs[1]
|
| 1534 |
+
pooled_output = self.dropout(inputs=pooled_output, training=training)
|
| 1535 |
+
logits = self.classifier(inputs=pooled_output)
|
| 1536 |
+
reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
|
| 1537 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
|
| 1538 |
+
|
| 1539 |
+
if not return_dict:
|
| 1540 |
+
output = (reshaped_logits,) + outputs[2:]
|
| 1541 |
+
return ((loss,) + output) if loss is not None else output
|
| 1542 |
+
|
| 1543 |
+
return TFMultipleChoiceModelOutput(
|
| 1544 |
+
loss=loss,
|
| 1545 |
+
logits=reshaped_logits,
|
| 1546 |
+
hidden_states=outputs.hidden_states,
|
| 1547 |
+
attentions=outputs.attentions,
|
| 1548 |
+
)
|
| 1549 |
+
|
| 1550 |
+
def build(self, input_shape=None):
|
| 1551 |
+
if self.built:
|
| 1552 |
+
return
|
| 1553 |
+
self.built = True
|
| 1554 |
+
if getattr(self, "albert", None) is not None:
|
| 1555 |
+
with tf.name_scope(self.albert.name):
|
| 1556 |
+
self.albert.build(None)
|
| 1557 |
+
if getattr(self, "classifier", None) is not None:
|
| 1558 |
+
with tf.name_scope(self.classifier.name):
|
| 1559 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 1560 |
+
|
| 1561 |
+
|
| 1562 |
+
__all__ = [
|
| 1563 |
+
"TFAlbertPreTrainedModel",
|
| 1564 |
+
"TFAlbertModel",
|
| 1565 |
+
"TFAlbertForPreTraining",
|
| 1566 |
+
"TFAlbertForMaskedLM",
|
| 1567 |
+
"TFAlbertForSequenceClassification",
|
| 1568 |
+
"TFAlbertForTokenClassification",
|
| 1569 |
+
"TFAlbertForQuestionAnswering",
|
| 1570 |
+
"TFAlbertForMultipleChoice",
|
| 1571 |
+
"TFAlbertMainLayer",
|
| 1572 |
+
]
|
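The question-answering head above returns raw `start_logits`/`end_logits` rather than a decoded answer, so a short usage sketch may help. It assumes the `vumichien/albert-base-v2-squad2` checkpoint named in the `add_code_sample_docstrings` decorator; the question and context strings are illustrative only:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAlbertForQuestionAnswering

# Checkpoint taken from the decorator above; any ALBERT QA checkpoint works the same way.
tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-squad2")
model = TFAlbertForQuestionAnswering.from_pretrained("vumichien/albert-base-v2-squad2")

question = "Who wrote the play?"  # illustrative inputs
context = "The play was written by William Shakespeare in the early 17th century."
inputs = tokenizer(question, context, return_tensors="tf")

outputs = model(**inputs)
# Greedy span decoding: take the argmax start and end positions and decode the tokens between them.
start = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
end = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```

Greedy argmax decoding can yield an empty or inverted span on hard inputs; production code usually searches over valid (start, end) pairs instead.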
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert.py
ADDED
@@ -0,0 +1,320 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ALBERT model."""

import os
import unicodedata
from shutil import copyfile
from typing import Any, Optional

import sentencepiece as spm

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from ...utils.import_utils import requires


logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}


SPIECE_UNDERLINE = "▁"


@requires(backends=("sentencepiece",))
class AlbertTokenizer(PreTrainedTokenizer):
    """
    Construct an ALBERT tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        remove_space (`bool`, *optional*, defaults to `True`):
            Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
        keep_accents (`bool`, *optional*, defaults to `False`):
            Whether or not to keep accents when tokenizing.
        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
                using the forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        remove_space=True,
        keep_accents=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="<unk>",
        sep_token="[SEP]",
        pad_token="<pad>",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # The mask token behaves like a normal word, i.e. it includes the space before it and
        # is included in the raw text, so there should be a match in a non-normalized sentence.
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents
        self.vocab_file = vocab_file

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        super().__init__(
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            keep_accents=keep_accents,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self.sp_model)

    def get_vocab(self) -> dict[str, int]:
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def preprocess_text(self, inputs):
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs
        outputs = outputs.replace("``", '"').replace("''", '"')

        if not self.keep_accents:
            outputs = unicodedata.normalize("NFKD", outputs)
            outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    def _tokenize(self, text: str) -> list[str]:
        """Tokenize a string."""
        text = self.preprocess_text(text)
        pieces = self.sp_model.encode(text, out_type=str)
        new_pieces = []
        for piece in pieces:
            if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
                # Logic to handle special cases; see https://github.com/google-research/bert/blob/master/README.md#tokenization
                # `9,9` -> ['▁9', ',', '9'] instead of [`_9,`, '9']
                cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
                if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                    if len(cur_pieces[0]) == 1:
                        cur_pieces = cur_pieces[1:]
                    else:
                        cur_pieces[0] = cur_pieces[0][1:]
                cur_pieces.append(piece[-1])
                new_pieces.extend(cur_pieces)
            else:
                new_pieces.append(piece)

        return new_pieces

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using the sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. An ALBERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)


__all__ = ["AlbertTokenizer"]
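As a quick illustration of the `[CLS] X [SEP]` / `[CLS] A [SEP] B [SEP]` layout that `build_inputs_with_special_tokens` and `get_special_tokens_mask` implement, here is a minimal sketch, assuming the `albert/albert-base-v2` checkpoint used in the docstrings earlier in this diff:

```python
from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained("albert/albert-base-v2")

ids_a = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("first sentence"))
ids_b = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("second one"))

# Pair layout: [CLS] A [SEP] B [SEP]
pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)

# The mask marks exactly the three special tokens added above.
mask = tokenizer.get_special_tokens_mask(ids_a, ids_b)
assert len(mask) == len(pair)
assert mask.count(1) == 3  # one [CLS], two [SEP]
```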
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/albert/tokenization_albert_fast.py
ADDED
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ALBERT model."""

import os
from shutil import copyfile
from typing import Optional

from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
    from .tokenization_albert import AlbertTokenizer
else:
    AlbertTokenizer = None

logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}


SPIECE_UNDERLINE = "▁"


class AlbertTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" ALBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on
    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models). This
    tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        remove_space (`bool`, *optional*, defaults to `True`):
            Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
        keep_accents (`bool`, *optional*, defaults to `False`):
            Whether or not to keep accents when tokenizing.
        bos_token (`str`, *optional*, defaults to `"[CLS]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = AlbertTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        remove_space=True,
        keep_accents=False,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="<unk>",
        sep_token="[SEP]",
        pad_token="<pad>",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs,
    ):
        # The mask token behaves like a normal word, i.e. it includes the space before it and
        # is included in the raw text, so there should be a match in a non-normalized sentence.
        mask_token = (
            AddedToken(mask_token, lstrip=True, rstrip=False, normalized=False)
            if isinstance(mask_token, str)
            else mask_token
        )

        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            keep_accents=keep_accents,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.keep_accents = keep_accents
        self.vocab_file = vocab_file

    def build_inputs_with_special_tokens(
        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
    ) -> list[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. An ALBERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)


__all__ = ["AlbertTokenizerFast"]
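One behavioral detail worth showing: `save_vocabulary` on the fast tokenizer only works when the original `spiece.model` is available, which the `can_save_slow_tokenizer` check above enforces. A minimal sketch, again assuming the `albert/albert-base-v2` checkpoint:

```python
import tempfile

from transformers import AlbertTokenizerFast

tokenizer = AlbertTokenizerFast.from_pretrained("albert/albert-base-v2")

with tempfile.TemporaryDirectory() as tmp_dir:
    if tokenizer.can_save_slow_tokenizer:
        # Copies spiece.model into tmp_dir and returns the path it wrote.
        print(tokenizer.save_vocabulary(tmp_dir))
    else:
        # Without the original SentencePiece file, save_vocabulary raises ValueError,
        # but save_pretrained can still serialize the fast tokenizer.json.
        tokenizer.save_pretrained(tmp_dir)
```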
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .auto_factory import *
    from .configuration_auto import *
    from .feature_extraction_auto import *
    from .image_processing_auto import *
    from .modeling_auto import *
    from .modeling_flax_auto import *
    from .modeling_tf_auto import *
    from .processing_auto import *
    from .tokenization_auto import *
    from .video_processing_auto import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
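
The `else` branch above replaces the package's module object in `sys.modules` with a `_LazyModule` proxy, so the heavy framework-specific submodules are only imported on first attribute access. Below is a self-contained sketch of that pattern; it is illustrative only, not the actual `_LazyModule` implementation, and the `LazySubmoduleProxy` name and mapping are made up:

```python
# Sketch of the lazy-module pattern: defer submodule imports until first access.
import importlib
import sys
import types


class LazySubmoduleProxy(types.ModuleType):
    """Imports the submodule that defines an attribute only when it is first used."""

    def __init__(self, name, attr_to_submodule):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule  # attribute name -> submodule path

    def __getattr__(self, attr):
        submodule_path = self._attr_to_submodule.get(attr)
        if submodule_path is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(submodule_path), attr)  # real import here
        setattr(self, attr, value)  # cache: subsequent accesses skip __getattr__
        return value


# A package __init__ would install the proxy over itself, analogous to the line above:
# sys.modules[__name__] = LazySubmoduleProxy(__name__, {"AutoConfig": "transformers.models.auto.configuration_auto"})
```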
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py
ADDED
@@ -0,0 +1,882 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""

import copy
import importlib
import json
import os
import warnings
from collections import OrderedDict
from collections.abc import Iterator
from typing import Any, TypeVar, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import (
    CONFIG_NAME,
    cached_file,
    copy_func,
    extract_commit_hash,
    find_adapter_config_file,
    is_peft_available,
    is_torch_available,
    logging,
    requires_backends,
)
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


if is_torch_available():
    from ...generation import GenerationMixin


logger = logging.get_logger(__name__)

_T = TypeVar("_T")
# Tokenizers will depend on packages installed, too much variance and there is no common base or Protocol
_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]

CLASS_DOCSTRING = """
    This is a generic model class that will be instantiated as one of the model classes of the library when created
    with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
    method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
"""

FROM_CONFIG_DOCSTRING = """
    Instantiates one of the model classes of the library from a configuration.

    Note:
        Loading a model from its configuration file does **not** load the model weights. It only affects the
        model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.

    Args:
        config ([`PretrainedConfig`]):
            The model class to instantiate is selected based on the configuration class:

            List options
        attn_implementation (`str`, *optional*):
            The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

    Examples:

    ```python
    >>> from transformers import AutoConfig, BaseAutoModelClass

    >>> # Download configuration from huggingface.co and cache.
    >>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
    >>> model = BaseAutoModelClass.from_config(config)
    ```
"""

FROM_PRETRAINED_TORCH_DOCSTRING = """
    Instantiate one of the model classes of the library from a pretrained model.

    The model class to instantiate is selected based on the `model_type` property of the config object (either
    passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
    falling back to using pattern matching on `pretrained_model_name_or_path`:

    List options

    The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
    deactivated). To train the model, you should first set it back in training mode with `model.train()`.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
                  this case, `from_tf` should be set to `True` and a configuration object should be provided as
                  `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                  PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
        model_args (additional positional arguments, *optional*):
            Will be passed along to the underlying model `__init__()` method.
        config ([`PretrainedConfig`], *optional*):
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

                - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                  model).
                - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                  save directory.
                - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                  configuration JSON file named *config.json* is found in the directory.
        state_dict (*dict[str, torch.Tensor]*, *optional*):
            A state dictionary to use instead of a state dictionary loaded from saved weights file.

            This option can be used if you want to create a model from a pretrained configuration but load your own
            weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
            [`~PreTrainedModel.from_pretrained`] is not a simpler option.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_tf (`bool`, *optional*, defaults to `False`):
            Load the model weights from a TensorFlow checkpoint save file (see docstring of
            `pretrained_model_name_or_path` argument).
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (e.g., not try downloading the model).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than
            the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
            system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
            allowed by git.
        kwargs (additional keyword arguments, *optional*):
            Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
            `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
            automatically loaded:

                - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                  underlying model's `__init__` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                  initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                  corresponds to a configuration attribute will be used to override said attribute with the
                  supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                  will be passed to the underlying model's `__init__` function.

    Examples:

    ```python
    >>> from transformers import AutoConfig, BaseAutoModelClass

    >>> # Download model and configuration from huggingface.co and cache.
    >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")

    >>> # Update configuration during loading
    >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
    >>> model.config.output_attentions
    True

    >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
    >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
    >>> model = BaseAutoModelClass.from_pretrained(
    ...     "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
    ... )
    ```
"""

FROM_PRETRAINED_TF_DOCSTRING = """
    Instantiate one of the model classes of the library from a pretrained model.

    The model class to instantiate is selected based on the `model_type` property of the config object (either
    passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
    falling back to using pattern matching on `pretrained_model_name_or_path`:

    List options

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In this
                  case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
                  argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
                  using the provided conversion scripts and loading the TensorFlow model afterwards.
        model_args (additional positional arguments, *optional*):
            Will be passed along to the underlying model `__init__()` method.
        config ([`PretrainedConfig`], *optional*):
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

                - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                  model).
                - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                  save directory.
                - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                  configuration JSON file named *config.json* is found in the directory.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_pt (`bool`, *optional*, defaults to `False`):
            Load the model weights from a PyTorch checkpoint save file (see docstring of
            `pretrained_model_name_or_path` argument).
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (e.g., not try downloading the model).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than
            the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
            system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
            allowed by git.
        kwargs (additional keyword arguments, *optional*):
            Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
            `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
            automatically loaded:

                - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                  underlying model's `__init__` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                  initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                  corresponds to a configuration attribute will be used to override said attribute with the
                  supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                  will be passed to the underlying model's `__init__` function.

    Examples:

    ```python
    >>> from transformers import AutoConfig, BaseAutoModelClass

    >>> # Download model and configuration from huggingface.co and cache.
    >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")

    >>> # Update configuration during loading
    >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
    >>> model.config.output_attentions
    True

    >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
    >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
    >>> model = BaseAutoModelClass.from_pretrained(
    ...     "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
    ... )
    ```
"""


FROM_PRETRAINED_FLAX_DOCSTRING = """
    Instantiate one of the model classes of the library from a pretrained model.

    The model class to instantiate is selected based on the `model_type` property of the config object (either
    passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
    falling back to using pattern matching on `pretrained_model_name_or_path`:

    List options

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In this
                  case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
                  argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
                  using the provided conversion scripts and loading the TensorFlow model afterwards.
        model_args (additional positional arguments, *optional*):
            Will be passed along to the underlying model `__init__()` method.
        config ([`PretrainedConfig`], *optional*):
            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

                - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                  model).
                - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                  save directory.
                - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                  configuration JSON file named *config.json* is found in the directory.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_pt (`bool`, *optional*, defaults to `False`):
            Load the model weights from a PyTorch checkpoint save file (see docstring of
            `pretrained_model_name_or_path` argument).
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (e.g., not try downloading the model).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than
            the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
            system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
            allowed by git.
        kwargs (additional keyword arguments, *optional*):
            Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
            `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
            automatically loaded:

                - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                  underlying model's `__init__` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                  initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                  corresponds to a configuration attribute will be used to override said attribute with the
                  supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                  will be passed to the underlying model's `__init__` function.

    Examples:

    ```python
    >>> from transformers import AutoConfig, BaseAutoModelClass

    >>> # Download model and configuration from huggingface.co and cache.
    >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")

    >>> # Update configuration during loading
    >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
    >>> model.config.output_attentions
    True

    >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
    >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
    >>> model = BaseAutoModelClass.from_pretrained(
    ...     "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
    ... )
    ```
"""


def _get_model_class(config, model_mapping):
    supported_models = model_mapping[type(config)]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    name_to_model = {model.__name__: model for model in supported_models}
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in name_to_model:
            return name_to_model[arch]
        elif f"TF{arch}" in name_to_model:
            return name_to_model[f"TF{arch}"]
        elif f"Flax{arch}" in name_to_model:
            return name_to_model[f"Flax{arch}"]

    # If no architecture is set in the config, or none matches the supported models, the first element of the tuple
    # is the default.
    return supported_models[0]


class _BaseAutoModelClass:
    # Base class for auto models.
    _model_mapping = None

    def __init__(self, *args, **kwargs) -> None:
        raise OSError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping
        if has_remote_code:
            class_ref = config.auto_map[cls.__name__]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo
            )

        if has_remote_code and trust_remote_code:
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            # This block handles the case where the user is loading a model with `trust_remote_code=True`
            # but a library model exists with the same name. We don't want to override the autoclass
            # mappings in this case, or all future loads of that model will be the remote code model.
            if not has_local_code:
                cls.register(config.__class__, model_class, exist_ok=True)
                model_class.register_for_auto_class(auto_class=cls)
            _ = kwargs.pop("code_revision", None)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class._from_config(config, **kwargs)
        elif type(config) in cls._model_mapping:
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class._from_config(config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
        )

    @classmethod
    def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
        """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
        return config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.get("trust_remote_code")
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
            "token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
        code_revision = kwargs.pop("code_revision", None)
        commit_hash = kwargs.pop("_commit_hash", None)
        adapter_kwargs = kwargs.pop("adapter_kwargs", None)

        token = hub_kwargs.pop("token", None)
        use_auth_token = hub_kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None:
            hub_kwargs["token"] = token

        if commit_hash is None:
            if not isinstance(config, PretrainedConfig):
                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    CONFIG_NAME,
                    _raise_exceptions_for_gated_repo=False,
                    _raise_exceptions_for_missing_entries=False,
                    _raise_exceptions_for_connection_errors=False,
                    **hub_kwargs,
                )
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            else:
                commit_hash = getattr(config, "_commit_hash", None)

        if is_peft_available():
            if adapter_kwargs is None:
                adapter_kwargs = {}
                if token is not None:
                    adapter_kwargs["token"] = token

            maybe_adapter_path = find_adapter_config_file(
                pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
            )

            if maybe_adapter_path is not None:
                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)

                    adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                    pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # ensure not to pollute the config object with dtype="auto" - since it's
            # meaningless in the context of the config object - torch.dtype values are acceptable
            if kwargs.get("torch_dtype") == "auto":
                _ = kwargs.pop("torch_dtype")
            if kwargs.get("dtype") == "auto":
                _ = kwargs.pop("dtype")
            # to not overwrite the quantization_config if config has a quantization_config
            if kwargs.get("quantization_config") is not None:
                _ = kwargs.pop("quantization_config")

            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                code_revision=code_revision,
                _commit_hash=commit_hash,
                **hub_kwargs,
                **kwargs,
            )

            # if torch_dtype=auto was passed here, ensure to pass it on
            if kwargs_orig.get("torch_dtype", None) == "auto":
                kwargs["torch_dtype"] = "auto"
            if kwargs_orig.get("dtype", None) == "auto":
                kwargs["dtype"] = "auto"
            if kwargs_orig.get("quantization_config", None) is not None:
                kwargs["quantization_config"] = kwargs_orig["quantization_config"]

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping
        upstream_repo = None
        if has_remote_code:
            class_ref = config.auto_map[cls.__name__]
            if "--" in class_ref:
                upstream_repo = class_ref.split("--")[0]
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code,
            pretrained_model_name_or_path,
            has_local_code,
            has_remote_code,
            upstream_repo=upstream_repo,
        )
        kwargs["trust_remote_code"] = trust_remote_code

        # Set the adapter kwargs
        kwargs["adapter_kwargs"] = adapter_kwargs

        if has_remote_code and trust_remote_code:
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
            )
            _ = hub_kwargs.pop("code_revision", None)
            # This block handles the case where the user is loading a model with `trust_remote_code=True`
            # but a library model exists with the same name. We don't want to override the autoclass
            # mappings in this case, or all future loads of that model will be the remote code model.
            if not has_local_code:
                cls.register(config.__class__, model_class, exist_ok=True)
                model_class.register_for_auto_class(auto_class=cls)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        elif type(config) in cls._model_mapping:
            model_class = _get_model_class(config, cls._model_mapping)
            if model_class.config_class == config.sub_configs.get("text_config", None):
                config = config.get_text_config()
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping)}."
        )

    @classmethod
    def register(cls, config_class, model_class, exist_ok=False) -> None:
        """
        Register a new model for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            model_class ([`PreTrainedModel`]):
                The model to register.
        """
        if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__:
            raise ValueError(
                "The model class you are passing has a `config_class` attribute that is not consistent with the "
                f"config class you passed (model has {model_class.config_class} and you passed {config_class}). Fix "
                "one of those so they match!"
            )
        cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)


class _BaseAutoBackboneClass(_BaseAutoModelClass):
    # Base class for auto backbone models.
    _model_mapping = None

    @classmethod
    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        if kwargs.get("out_features") is not None:
            raise ValueError("Cannot specify `out_features` for timm backbones")

        if kwargs.get("output_loading_info", False):
            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
        )
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        use_timm_backbone = kwargs.pop("use_timm_backbone", False)
        if use_timm_backbone:
            return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


def insert_head_doc(docstring, head_doc: str = ""):
    if len(head_doc) > 0:
        return docstring.replace(
            "one of the model classes of the library ",
            f"one of the model classes of the library (with a {head_doc} head) ",
        )
    return docstring.replace(
        "one of the model classes of the library ", "one of the base model classes of the library "
    )


def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
    # Create a new class with the right name from the base class
    model_mapping = cls._model_mapping
    name = cls.__name__
    class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
    cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)

    # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
    # have specific docstrings for them.
    from_config = copy_func(_BaseAutoModelClass.from_config)
    from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
    from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
    from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    from_config.__doc__ = from_config_docstring
    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
    cls.from_config = classmethod(from_config)

    if name.startswith("TF"):
        from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
    elif name.startswith("Flax"):
        from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
    else:
        from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
    from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
    from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
    from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
    from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
    from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
    from_pretrained.__doc__ = from_pretrained_docstring
    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
    cls.from_pretrained = classmethod(from_pretrained)
    return cls


def get_values(model_mapping):
    result = []
    for model in model_mapping.values():
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)

    return result


def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
    # object at the top level.
    transformers_module = importlib.import_module("transformers")

    if module != transformers_module:
        try:
            return getattribute_from_module(transformers_module, attr)
        except ValueError:
            raise ValueError(f"Could not find {attr} in either {module} or {transformers_module}!")
    else:
        raise ValueError(f"Could not find {attr} in {transformers_module}!")


def add_generation_mixin_to_remote_model(model_class):
    """
    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.

    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
    from the Hub may not have the `generate` method after we remove the inheritance.
    """
    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
        return model_class

    # 2. If it already **directly** inherits from GenerationMixin, do nothing
    if "GenerationMixin" in str(model_class.__bases__):
        return model_class

    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
    # `prepare_inputs_for_generation` method.
    has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
        getattr(model_class, "generate")
    )
    has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
        getattr(model_class, "prepare_inputs_for_generation")
    )
    if has_custom_generate_in_class or has_custom_prepare_inputs:
        model_class_with_generation_mixin = type(
            model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
        )
        return model_class_with_generation_mixin
    return model_class

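# --- Illustrative aside (not part of auto_factory.py) ---------------------------
# The dynamic re-parenting done by `add_generation_mixin_to_remote_model` can be
# reproduced in isolation. Hypothetical sketch: `type()` builds a new class whose
# MRO places `GenerationMixin` after the original bases, so `generate` resolves:
#
#     class RemoteModel(torch.nn.Module):          # stands in for a Hub model class
#         def prepare_inputs_for_generation(self, *args, **kwargs): ...
#
#     Patched = type(RemoteModel.__name__, (RemoteModel, GenerationMixin), {**RemoteModel.__dict__})
#     assert hasattr(Patched, "generate")
# ---------------------------------------------------------------------------------
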
class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
    """
    A mapping from config to object (model or tokenizer, for instance) that will load keys and values when it is
    accessed.

    Args:
        - config_mapping: The map model type to config class
        - model_mapping: The map model type to model (or tokenizer) class
    """

    def __init__(self, config_mapping, model_mapping) -> None:
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._model_mapping._model_mapping = self
        self._extra_content = {}
        self._modules = {}

    def __len__(self) -> int:
        common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
        return len(common_keys) + len(self._extra_content)

    def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
        if key in self._extra_content:
            return self._extra_content[key]
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type in self._model_mapping:
            model_name = self._model_mapping[model_type]
            return self._load_attr_from_module(model_type, model_name)

        # Maybe there were several model types associated with this config.
        model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
        for mtype in model_types:
            if mtype in self._model_mapping:
                model_name = self._model_mapping[mtype]
                return self._load_attr_from_module(mtype, model_name)
        raise KeyError(key)

    def _load_attr_from_module(self, model_type, attr):
        module_name = model_type_to_module_name(model_type)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattribute_from_module(self._modules[module_name], attr)

    def keys(self) -> list[type[PretrainedConfig]]:
        mapping_keys = [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping
        ]
        return mapping_keys + list(self._extra_content.keys())

    def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self) -> bool:
        return bool(self.keys())

    def values(self) -> list[_LazyAutoMappingValue]:
        mapping_values = [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping
        ]
        return mapping_values + list(self._extra_content.values())

    def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
        mapping_items = [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping
            if key in self._config_mapping
        ]
        return mapping_items + list(self._extra_content.items())

    def __iter__(self) -> Iterator[type[PretrainedConfig]]:
        return iter(self.keys())

    def __contains__(self, item: type) -> bool:
        if item in self._extra_content:
            return True
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping

    def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
        """
        Register a new model in this mapping.
        """
        if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
            model_type = self._reverse_config_mapping[key.__name__]
            if model_type in self._model_mapping and not exist_ok:
                raise ValueError(f"'{key}' is already used by a Transformers model.")

        self._extra_content[key] = value


__all__ = ["get_values"]
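
Taken together, `_BaseAutoModelClass.register` and `_LazyAutoMapping.register` are what let user code attach a custom config/model pair to the auto classes. A minimal sketch of that flow follows; the `MyConfig`/`MyModel` names and the `"my-model"` model type are hypothetical placeholders:

```python
# Sketch: registering a custom model so AutoModel can resolve it via the mapping above.
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class MyConfig(PretrainedConfig):
    model_type = "my-model"  # must be unique among registered model types

    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class MyModel(PreTrainedModel):
    config_class = MyConfig  # consistency is checked by `register` above

    def __init__(self, config):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.proj(x)


AutoConfig.register("my-model", MyConfig)
AutoModel.register(MyConfig, MyModel)  # lands in _LazyAutoMapping._extra_content

model = AutoModel.from_config(MyConfig())  # resolved through _get_model_class
```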
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/configuration_auto.py
ADDED
@@ -0,0 +1,1404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Auto Config class."""
|
| 16 |
+
|
| 17 |
+
import importlib
|
| 18 |
+
import os
|
| 19 |
+
import re
|
| 20 |
+
import warnings
|
| 21 |
+
from collections import OrderedDict
|
| 22 |
+
from collections.abc import Callable, Iterator, KeysView, ValuesView
|
| 23 |
+
from typing import Any, TypeVar, Union
|
| 24 |
+
|
| 25 |
+
from ...configuration_utils import PretrainedConfig
|
| 26 |
+
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
| 27 |
+
from ...utils import CONFIG_NAME, logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logger = logging.get_logger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
CONFIG_MAPPING_NAMES = OrderedDict[str, str](
    [
        # Add configs here
        ("aimv2", "Aimv2Config"),
        ("aimv2_vision_model", "Aimv2VisionConfig"),
        ("albert", "AlbertConfig"),
        ("align", "AlignConfig"),
        ("altclip", "AltCLIPConfig"),
        ("apertus", "ApertusConfig"),
        ("arcee", "ArceeConfig"),
        ("aria", "AriaConfig"),
        ("aria_text", "AriaTextConfig"),
        ("audio-spectrogram-transformer", "ASTConfig"),
        ("autoformer", "AutoformerConfig"),
        ("aya_vision", "AyaVisionConfig"),
        ("bamba", "BambaConfig"),
        ("bark", "BarkConfig"),
        ("bart", "BartConfig"),
        ("beit", "BeitConfig"),
        ("bert", "BertConfig"),
        ("bert-generation", "BertGenerationConfig"),
        ("big_bird", "BigBirdConfig"),
        ("bigbird_pegasus", "BigBirdPegasusConfig"),
        ("biogpt", "BioGptConfig"),
        ("bit", "BitConfig"),
        ("bitnet", "BitNetConfig"),
        ("blenderbot", "BlenderbotConfig"),
        ("blenderbot-small", "BlenderbotSmallConfig"),
        ("blip", "BlipConfig"),
        ("blip-2", "Blip2Config"),
        ("blip_2_qformer", "Blip2QFormerConfig"),
        ("bloom", "BloomConfig"),
        ("blt", "BltConfig"),
        ("bridgetower", "BridgeTowerConfig"),
        ("bros", "BrosConfig"),
        ("camembert", "CamembertConfig"),
        ("canine", "CanineConfig"),
        ("chameleon", "ChameleonConfig"),
        ("chinese_clip", "ChineseCLIPConfig"),
        ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
        ("clap", "ClapConfig"),
        ("clip", "CLIPConfig"),
        ("clip_text_model", "CLIPTextConfig"),
        ("clip_vision_model", "CLIPVisionConfig"),
        ("clipseg", "CLIPSegConfig"),
        ("clvp", "ClvpConfig"),
        ("code_llama", "LlamaConfig"),
        ("codegen", "CodeGenConfig"),
        ("cohere", "CohereConfig"),
        ("cohere2", "Cohere2Config"),
        ("cohere2_vision", "Cohere2VisionConfig"),
        ("colpali", "ColPaliConfig"),
        ("colqwen2", "ColQwen2Config"),
        ("conditional_detr", "ConditionalDetrConfig"),
        ("convbert", "ConvBertConfig"),
        ("convnext", "ConvNextConfig"),
        ("convnextv2", "ConvNextV2Config"),
        ("cpmant", "CpmAntConfig"),
        ("csm", "CsmConfig"),
        ("ctrl", "CTRLConfig"),
        ("cvt", "CvtConfig"),
        ("d_fine", "DFineConfig"),
        ("dab-detr", "DabDetrConfig"),
        ("dac", "DacConfig"),
        ("data2vec-audio", "Data2VecAudioConfig"),
        ("data2vec-text", "Data2VecTextConfig"),
        ("data2vec-vision", "Data2VecVisionConfig"),
        ("dbrx", "DbrxConfig"),
        ("deberta", "DebertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("decision_transformer", "DecisionTransformerConfig"),
        ("deepseek_v2", "DeepseekV2Config"),
        ("deepseek_v3", "DeepseekV3Config"),
        ("deepseek_vl", "DeepseekVLConfig"),
        ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"),
        ("deformable_detr", "DeformableDetrConfig"),
        ("deit", "DeiTConfig"),
        ("depth_anything", "DepthAnythingConfig"),
        ("depth_pro", "DepthProConfig"),
        ("deta", "DetaConfig"),
        ("detr", "DetrConfig"),
        ("dia", "DiaConfig"),
        ("diffllama", "DiffLlamaConfig"),
        ("dinat", "DinatConfig"),
        ("dinov2", "Dinov2Config"),
        ("dinov2_with_registers", "Dinov2WithRegistersConfig"),
        ("dinov3_convnext", "DINOv3ConvNextConfig"),
        ("dinov3_vit", "DINOv3ViTConfig"),
        ("distilbert", "DistilBertConfig"),
        ("doge", "DogeConfig"),
        ("donut-swin", "DonutSwinConfig"),
        ("dots1", "Dots1Config"),
        ("dpr", "DPRConfig"),
        ("dpt", "DPTConfig"),
        ("edgetam", "EdgeTamConfig"),
        ("edgetam_video", "EdgeTamVideoConfig"),
        ("edgetam_vision_model", "EdgeTamVisionConfig"),
        ("efficientformer", "EfficientFormerConfig"),
        ("efficientloftr", "EfficientLoFTRConfig"),
        ("efficientnet", "EfficientNetConfig"),
        ("electra", "ElectraConfig"),
        ("emu3", "Emu3Config"),
        ("encodec", "EncodecConfig"),
        ("encoder-decoder", "EncoderDecoderConfig"),
        ("eomt", "EomtConfig"),
        ("ernie", "ErnieConfig"),
        ("ernie4_5", "Ernie4_5Config"),
        ("ernie4_5_moe", "Ernie4_5_MoeConfig"),
        ("ernie_m", "ErnieMConfig"),
        ("esm", "EsmConfig"),
        ("evolla", "EvollaConfig"),
        ("exaone4", "Exaone4Config"),
        ("falcon", "FalconConfig"),
        ("falcon_h1", "FalconH1Config"),
        ("falcon_mamba", "FalconMambaConfig"),
        ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
        ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
        ("flaubert", "FlaubertConfig"),
        ("flava", "FlavaConfig"),
        ("flex_olmo", "FlexOlmoConfig"),
        ("florence2", "Florence2Config"),
        ("fnet", "FNetConfig"),
        ("focalnet", "FocalNetConfig"),
        ("fsmt", "FSMTConfig"),
        ("funnel", "FunnelConfig"),
        ("fuyu", "FuyuConfig"),
        ("gemma", "GemmaConfig"),
        ("gemma2", "Gemma2Config"),
        ("gemma3", "Gemma3Config"),
        ("gemma3_text", "Gemma3TextConfig"),
        ("gemma3n", "Gemma3nConfig"),
        ("gemma3n_audio", "Gemma3nAudioConfig"),
        ("gemma3n_text", "Gemma3nTextConfig"),
        ("gemma3n_vision", "Gemma3nVisionConfig"),
        ("git", "GitConfig"),
        ("glm", "GlmConfig"),
        ("glm4", "Glm4Config"),
        ("glm4_moe", "Glm4MoeConfig"),
        ("glm4v", "Glm4vConfig"),
        ("glm4v_moe", "Glm4vMoeConfig"),
        ("glm4v_moe_text", "Glm4vMoeTextConfig"),
        ("glm4v_text", "Glm4vTextConfig"),
        ("glpn", "GLPNConfig"),
        ("got_ocr2", "GotOcr2Config"),
        ("gpt-sw3", "GPT2Config"),
        ("gpt2", "GPT2Config"),
        ("gpt_bigcode", "GPTBigCodeConfig"),
        ("gpt_neo", "GPTNeoConfig"),
        ("gpt_neox", "GPTNeoXConfig"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
        ("gpt_oss", "GptOssConfig"),
        ("gptj", "GPTJConfig"),
        ("gptsan-japanese", "GPTSanJapaneseConfig"),
        ("granite", "GraniteConfig"),
        ("granite_speech", "GraniteSpeechConfig"),
        ("granitemoe", "GraniteMoeConfig"),
        ("granitemoehybrid", "GraniteMoeHybridConfig"),
        ("granitemoeshared", "GraniteMoeSharedConfig"),
        ("granitevision", "LlavaNextConfig"),
        ("graphormer", "GraphormerConfig"),
        ("grounding-dino", "GroundingDinoConfig"),
        ("groupvit", "GroupViTConfig"),
        ("helium", "HeliumConfig"),
        ("hgnet_v2", "HGNetV2Config"),
        ("hiera", "HieraConfig"),
        ("hubert", "HubertConfig"),
        ("hunyuan_v1_dense", "HunYuanDenseV1Config"),
        ("hunyuan_v1_moe", "HunYuanMoEV1Config"),
        ("ibert", "IBertConfig"),
        ("idefics", "IdeficsConfig"),
        ("idefics2", "Idefics2Config"),
        ("idefics3", "Idefics3Config"),
        ("idefics3_vision", "Idefics3VisionConfig"),
        ("ijepa", "IJepaConfig"),
        ("imagegpt", "ImageGPTConfig"),
        ("informer", "InformerConfig"),
        ("instructblip", "InstructBlipConfig"),
        ("instructblipvideo", "InstructBlipVideoConfig"),
        ("internvl", "InternVLConfig"),
        ("internvl_vision", "InternVLVisionConfig"),
        ("jamba", "JambaConfig"),
        ("janus", "JanusConfig"),
        ("jetmoe", "JetMoeConfig"),
        ("jukebox", "JukeboxConfig"),
        ("kosmos-2", "Kosmos2Config"),
        ("kosmos-2.5", "Kosmos2_5Config"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
        ("layoutlm", "LayoutLMConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("layoutlmv3", "LayoutLMv3Config"),
        ("led", "LEDConfig"),
        ("levit", "LevitConfig"),
        ("lfm2", "Lfm2Config"),
        ("lfm2_vl", "Lfm2VlConfig"),
        ("lightglue", "LightGlueConfig"),
        ("lilt", "LiltConfig"),
        ("llama", "LlamaConfig"),
        ("llama4", "Llama4Config"),
        ("llama4_text", "Llama4TextConfig"),
        ("llava", "LlavaConfig"),
        ("llava_next", "LlavaNextConfig"),
        ("llava_next_video", "LlavaNextVideoConfig"),
        ("llava_onevision", "LlavaOnevisionConfig"),
        ("longcat_flash", "LongcatFlashConfig"),
        ("longformer", "LongformerConfig"),
        ("longt5", "LongT5Config"),
        ("luke", "LukeConfig"),
        ("lxmert", "LxmertConfig"),
        ("m2m_100", "M2M100Config"),
        ("mamba", "MambaConfig"),
        ("mamba2", "Mamba2Config"),
        ("marian", "MarianConfig"),
        ("markuplm", "MarkupLMConfig"),
        ("mask2former", "Mask2FormerConfig"),
        ("maskformer", "MaskFormerConfig"),
        ("maskformer-swin", "MaskFormerSwinConfig"),
        ("mbart", "MBartConfig"),
        ("mctct", "MCTCTConfig"),
        ("mega", "MegaConfig"),
        ("megatron-bert", "MegatronBertConfig"),
        ("metaclip_2", "MetaClip2Config"),
        ("mgp-str", "MgpstrConfig"),
        ("mimi", "MimiConfig"),
        ("minimax", "MiniMaxConfig"),
        ("ministral", "MinistralConfig"),
        ("mistral", "MistralConfig"),
        ("mistral3", "Mistral3Config"),
        ("mixtral", "MixtralConfig"),
        ("mlcd", "MLCDVisionConfig"),
        ("mllama", "MllamaConfig"),
        ("mm-grounding-dino", "MMGroundingDinoConfig"),
        ("mobilebert", "MobileBertConfig"),
        ("mobilenet_v1", "MobileNetV1Config"),
        ("mobilenet_v2", "MobileNetV2Config"),
        ("mobilevit", "MobileViTConfig"),
        ("mobilevitv2", "MobileViTV2Config"),
        ("modernbert", "ModernBertConfig"),
        ("modernbert-decoder", "ModernBertDecoderConfig"),
        ("moonshine", "MoonshineConfig"),
        ("moshi", "MoshiConfig"),
        ("mpnet", "MPNetConfig"),
        ("mpt", "MptConfig"),
        ("mra", "MraConfig"),
        ("mt5", "MT5Config"),
        ("musicgen", "MusicgenConfig"),
        ("musicgen_melody", "MusicgenMelodyConfig"),
        ("mvp", "MvpConfig"),
        ("nat", "NatConfig"),
        ("nemotron", "NemotronConfig"),
        ("nezha", "NezhaConfig"),
        ("nllb-moe", "NllbMoeConfig"),
        ("nougat", "VisionEncoderDecoderConfig"),
        ("nystromformer", "NystromformerConfig"),
        ("olmo", "OlmoConfig"),
        ("olmo2", "Olmo2Config"),
        ("olmo3", "Olmo3Config"),
        ("olmoe", "OlmoeConfig"),
        ("omdet-turbo", "OmDetTurboConfig"),
        ("oneformer", "OneFormerConfig"),
        ("open-llama", "OpenLlamaConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("opt", "OPTConfig"),
        ("ovis2", "Ovis2Config"),
        ("owlv2", "Owlv2Config"),
        ("owlvit", "OwlViTConfig"),
        ("paligemma", "PaliGemmaConfig"),
        ("parakeet_ctc", "ParakeetCTCConfig"),
        ("parakeet_encoder", "ParakeetEncoderConfig"),
        ("patchtsmixer", "PatchTSMixerConfig"),
        ("patchtst", "PatchTSTConfig"),
        ("pegasus", "PegasusConfig"),
        ("pegasus_x", "PegasusXConfig"),
        ("perceiver", "PerceiverConfig"),
        ("perception_encoder", "TimmWrapperConfig"),
        ("perception_lm", "PerceptionLMConfig"),
        ("persimmon", "PersimmonConfig"),
        ("phi", "PhiConfig"),
        ("phi3", "Phi3Config"),
        ("phi4_multimodal", "Phi4MultimodalConfig"),
        ("phimoe", "PhimoeConfig"),
        ("pix2struct", "Pix2StructConfig"),
        ("pixtral", "PixtralVisionConfig"),
        ("plbart", "PLBartConfig"),
        ("poolformer", "PoolFormerConfig"),
        ("pop2piano", "Pop2PianoConfig"),
        ("prompt_depth_anything", "PromptDepthAnythingConfig"),
        ("prophetnet", "ProphetNetConfig"),
        ("pvt", "PvtConfig"),
        ("pvt_v2", "PvtV2Config"),
        ("qdqbert", "QDQBertConfig"),
        ("qwen2", "Qwen2Config"),
        ("qwen2_5_omni", "Qwen2_5OmniConfig"),
        ("qwen2_5_vl", "Qwen2_5_VLConfig"),
        ("qwen2_5_vl_text", "Qwen2_5_VLTextConfig"),
        ("qwen2_audio", "Qwen2AudioConfig"),
        ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"),
        ("qwen2_moe", "Qwen2MoeConfig"),
        ("qwen2_vl", "Qwen2VLConfig"),
        ("qwen2_vl_text", "Qwen2VLTextConfig"),
        ("qwen3", "Qwen3Config"),
        ("qwen3_moe", "Qwen3MoeConfig"),
        ("qwen3_next", "Qwen3NextConfig"),
        ("qwen3_omni_moe", "Qwen3OmniMoeConfig"),
        ("qwen3_vl", "Qwen3VLConfig"),
        ("qwen3_vl_moe", "Qwen3VLMoeConfig"),
        ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
        ("qwen3_vl_text", "Qwen3VLTextConfig"),
        ("rag", "RagConfig"),
        ("realm", "RealmConfig"),
        ("recurrent_gemma", "RecurrentGemmaConfig"),
        ("reformer", "ReformerConfig"),
        ("regnet", "RegNetConfig"),
        ("rembert", "RemBertConfig"),
        ("resnet", "ResNetConfig"),
        ("retribert", "RetriBertConfig"),
        ("roberta", "RobertaConfig"),
        ("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
        ("roc_bert", "RoCBertConfig"),
        ("roformer", "RoFormerConfig"),
        ("rt_detr", "RTDetrConfig"),
        ("rt_detr_resnet", "RTDetrResNetConfig"),
        ("rt_detr_v2", "RTDetrV2Config"),
        ("rwkv", "RwkvConfig"),
        ("sam", "SamConfig"),
        ("sam2", "Sam2Config"),
        ("sam2_hiera_det_model", "Sam2HieraDetConfig"),
        ("sam2_video", "Sam2VideoConfig"),
        ("sam2_vision_model", "Sam2VisionConfig"),
        ("sam_hq", "SamHQConfig"),
        ("sam_hq_vision_model", "SamHQVisionConfig"),
        ("sam_vision_model", "SamVisionConfig"),
        ("seamless_m4t", "SeamlessM4TConfig"),
        ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
        ("seed_oss", "SeedOssConfig"),
        ("segformer", "SegformerConfig"),
        ("seggpt", "SegGptConfig"),
        ("sew", "SEWConfig"),
        ("sew-d", "SEWDConfig"),
        ("shieldgemma2", "ShieldGemma2Config"),
        ("siglip", "SiglipConfig"),
        ("siglip2", "Siglip2Config"),
        ("siglip2_vision_model", "Siglip2VisionConfig"),
        ("siglip_vision_model", "SiglipVisionConfig"),
        ("smollm3", "SmolLM3Config"),
        ("smolvlm", "SmolVLMConfig"),
        ("smolvlm_vision", "SmolVLMVisionConfig"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
        ("speech_to_text", "Speech2TextConfig"),
        ("speech_to_text_2", "Speech2Text2Config"),
        ("speecht5", "SpeechT5Config"),
        ("splinter", "SplinterConfig"),
        ("squeezebert", "SqueezeBertConfig"),
        ("stablelm", "StableLmConfig"),
        ("starcoder2", "Starcoder2Config"),
        ("superglue", "SuperGlueConfig"),
        ("superpoint", "SuperPointConfig"),
        ("swiftformer", "SwiftFormerConfig"),
        ("swin", "SwinConfig"),
        ("swin2sr", "Swin2SRConfig"),
        ("swinv2", "Swinv2Config"),
        ("switch_transformers", "SwitchTransformersConfig"),
        ("t5", "T5Config"),
        ("t5gemma", "T5GemmaConfig"),
        ("table-transformer", "TableTransformerConfig"),
        ("tapas", "TapasConfig"),
        ("textnet", "TextNetConfig"),
        ("time_series_transformer", "TimeSeriesTransformerConfig"),
        ("timesfm", "TimesFmConfig"),
        ("timesformer", "TimesformerConfig"),
        ("timm_backbone", "TimmBackboneConfig"),
        ("timm_wrapper", "TimmWrapperConfig"),
        ("trajectory_transformer", "TrajectoryTransformerConfig"),
        ("transfo-xl", "TransfoXLConfig"),
        ("trocr", "TrOCRConfig"),
        ("tvlt", "TvltConfig"),
        ("tvp", "TvpConfig"),
        ("udop", "UdopConfig"),
        ("umt5", "UMT5Config"),
        ("unispeech", "UniSpeechConfig"),
        ("unispeech-sat", "UniSpeechSatConfig"),
        ("univnet", "UnivNetConfig"),
        ("upernet", "UperNetConfig"),
        ("van", "VanConfig"),
        ("vaultgemma", "VaultGemmaConfig"),
        ("video_llava", "VideoLlavaConfig"),
        ("videomae", "VideoMAEConfig"),
        ("vilt", "ViltConfig"),
        ("vipllava", "VipLlavaConfig"),
        ("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
        ("visual_bert", "VisualBertConfig"),
        ("vit", "ViTConfig"),
        ("vit_hybrid", "ViTHybridConfig"),
        ("vit_mae", "ViTMAEConfig"),
        ("vit_msn", "ViTMSNConfig"),
        ("vitdet", "VitDetConfig"),
        ("vitmatte", "VitMatteConfig"),
        ("vitpose", "VitPoseConfig"),
        ("vitpose_backbone", "VitPoseBackboneConfig"),
        ("vits", "VitsConfig"),
        ("vivit", "VivitConfig"),
        ("vjepa2", "VJEPA2Config"),
        ("voxtral", "VoxtralConfig"),
        ("voxtral_encoder", "VoxtralEncoderConfig"),
        ("wav2vec2", "Wav2Vec2Config"),
        ("wav2vec2-bert", "Wav2Vec2BertConfig"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
        ("wavlm", "WavLMConfig"),
        ("whisper", "WhisperConfig"),
        ("xclip", "XCLIPConfig"),
        ("xcodec", "XcodecConfig"),
        ("xglm", "XGLMConfig"),
        ("xlm", "XLMConfig"),
        ("xlm-prophetnet", "XLMProphetNetConfig"),
        ("xlm-roberta", "XLMRobertaConfig"),
        ("xlm-roberta-xl", "XLMRobertaXLConfig"),
        ("xlnet", "XLNetConfig"),
        ("xlstm", "xLSTMConfig"),
        ("xmod", "XmodConfig"),
        ("yolos", "YolosConfig"),
        ("yoso", "YosoConfig"),
        ("zamba", "ZambaConfig"),
        ("zamba2", "Zamba2Config"),
        ("zoedepth", "ZoeDepthConfig"),
    ]
)

MODEL_NAMES_MAPPING = OrderedDict[str, str](
    [
        # Add full (and cased) model names here
        ("aimv2", "AIMv2"),
        ("aimv2_vision_model", "Aimv2VisionModel"),
        ("albert", "ALBERT"),
        ("align", "ALIGN"),
        ("altclip", "AltCLIP"),
        ("apertus", "Apertus"),
        ("arcee", "Arcee"),
        ("aria", "Aria"),
        ("aria_text", "AriaText"),
        ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
        ("autoformer", "Autoformer"),
        ("aya_vision", "AyaVision"),
        ("bamba", "Bamba"),
        ("bark", "Bark"),
        ("bart", "BART"),
        ("barthez", "BARThez"),
        ("bartpho", "BARTpho"),
        ("beit", "BEiT"),
        ("bert", "BERT"),
        ("bert-generation", "Bert Generation"),
        ("bert-japanese", "BertJapanese"),
        ("bertweet", "BERTweet"),
        ("big_bird", "BigBird"),
        ("bigbird_pegasus", "BigBird-Pegasus"),
        ("biogpt", "BioGpt"),
        ("bit", "BiT"),
        ("bitnet", "BitNet"),
        ("blenderbot", "Blenderbot"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("blip", "BLIP"),
        ("blip-2", "BLIP-2"),
        ("blip_2_qformer", "BLIP-2 QFormer"),
        ("bloom", "BLOOM"),
        ("blt", "Blt"),
        ("bort", "BORT"),
        ("bridgetower", "BridgeTower"),
        ("bros", "BROS"),
        ("byt5", "ByT5"),
        ("camembert", "CamemBERT"),
        ("canine", "CANINE"),
        ("chameleon", "Chameleon"),
        ("chinese_clip", "Chinese-CLIP"),
        ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
        ("clap", "CLAP"),
        ("clip", "CLIP"),
        ("clip_text_model", "CLIPTextModel"),
        ("clip_vision_model", "CLIPVisionModel"),
        ("clipseg", "CLIPSeg"),
        ("clvp", "CLVP"),
        ("code_llama", "CodeLlama"),
        ("codegen", "CodeGen"),
        ("cohere", "Cohere"),
        ("cohere2", "Cohere2"),
        ("cohere2_vision", "Cohere2Vision"),
        ("colpali", "ColPali"),
        ("colqwen2", "ColQwen2"),
        ("conditional_detr", "Conditional DETR"),
        ("convbert", "ConvBERT"),
        ("convnext", "ConvNeXT"),
        ("convnextv2", "ConvNeXTV2"),
        ("cpm", "CPM"),
        ("cpmant", "CPM-Ant"),
        ("csm", "CSM"),
        ("ctrl", "CTRL"),
        ("cvt", "CvT"),
        ("d_fine", "D-FINE"),
        ("dab-detr", "DAB-DETR"),
        ("dac", "DAC"),
        ("data2vec-audio", "Data2VecAudio"),
        ("data2vec-text", "Data2VecText"),
        ("data2vec-vision", "Data2VecVision"),
        ("dbrx", "DBRX"),
        ("deberta", "DeBERTa"),
        ("deberta-v2", "DeBERTa-v2"),
        ("decision_transformer", "Decision Transformer"),
        ("deepseek_v2", "DeepSeek-V2"),
        ("deepseek_v3", "DeepSeek-V3"),
        ("deepseek_vl", "DeepseekVL"),
        ("deepseek_vl_hybrid", "DeepseekVLHybrid"),
        ("deformable_detr", "Deformable DETR"),
        ("deit", "DeiT"),
        ("deplot", "DePlot"),
        ("depth_anything", "Depth Anything"),
        ("depth_anything_v2", "Depth Anything V2"),
        ("depth_pro", "DepthPro"),
        ("deta", "DETA"),
        ("detr", "DETR"),
        ("dia", "Dia"),
        ("dialogpt", "DialoGPT"),
        ("diffllama", "DiffLlama"),
        ("dinat", "DiNAT"),
        ("dinov2", "DINOv2"),
        ("dinov2_with_registers", "DINOv2 with Registers"),
        ("dinov3_convnext", "DINOv3 ConvNext"),
        ("dinov3_vit", "DINOv3 ViT"),
        ("distilbert", "DistilBERT"),
        ("dit", "DiT"),
        ("doge", "Doge"),
        ("donut-swin", "DonutSwin"),
        ("dots1", "dots1"),
        ("dpr", "DPR"),
        ("dpt", "DPT"),
        ("edgetam", "EdgeTAM"),
        ("edgetam_video", "EdgeTamVideo"),
        ("edgetam_vision_model", "EdgeTamVisionModel"),
        ("efficientformer", "EfficientFormer"),
        ("efficientloftr", "EfficientLoFTR"),
        ("efficientnet", "EfficientNet"),
        ("electra", "ELECTRA"),
        ("emu3", "Emu3"),
        ("encodec", "EnCodec"),
        ("encoder-decoder", "Encoder decoder"),
        ("eomt", "EoMT"),
        ("ernie", "ERNIE"),
        ("ernie4_5", "Ernie4_5"),
        ("ernie4_5_moe", "Ernie4_5_MoE"),
        ("ernie_m", "ErnieM"),
        ("esm", "ESM"),
        ("evolla", "Evolla"),
        ("exaone4", "EXAONE-4.0"),
        ("falcon", "Falcon"),
        ("falcon3", "Falcon3"),
        ("falcon_h1", "FalconH1"),
        ("falcon_mamba", "FalconMamba"),
        ("fastspeech2_conformer", "FastSpeech2Conformer"),
        ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
        ("flan-t5", "FLAN-T5"),
        ("flan-ul2", "FLAN-UL2"),
        ("flaubert", "FlauBERT"),
        ("flava", "FLAVA"),
        ("flex_olmo", "FlexOlmo"),
        ("florence2", "Florence2"),
        ("fnet", "FNet"),
        ("focalnet", "FocalNet"),
        ("fsmt", "FairSeq Machine-Translation"),
        ("funnel", "Funnel Transformer"),
        ("fuyu", "Fuyu"),
        ("gemma", "Gemma"),
        ("gemma2", "Gemma2"),
        ("gemma3", "Gemma3ForConditionalGeneration"),
        ("gemma3_text", "Gemma3ForCausalLM"),
        ("gemma3n", "Gemma3nForConditionalGeneration"),
        ("gemma3n_audio", "Gemma3nAudioEncoder"),
        ("gemma3n_text", "Gemma3nForCausalLM"),
        ("gemma3n_vision", "TimmWrapperModel"),
        ("git", "GIT"),
        ("glm", "GLM"),
        ("glm4", "GLM4"),
        ("glm4_moe", "Glm4MoE"),
        ("glm4v", "GLM4V"),
        ("glm4v_moe", "GLM4VMOE"),
        ("glm4v_moe_text", "GLM4VMOE"),
        ("glm4v_text", "GLM4V"),
        ("glpn", "GLPN"),
        ("got_ocr2", "GOT-OCR2"),
        ("gpt-sw3", "GPT-Sw3"),
        ("gpt2", "OpenAI GPT-2"),
        ("gpt_bigcode", "GPTBigCode"),
        ("gpt_neo", "GPT Neo"),
        ("gpt_neox", "GPT NeoX"),
        ("gpt_neox_japanese", "GPT NeoX Japanese"),
        ("gpt_oss", "GptOss"),
        ("gptj", "GPT-J"),
        ("gptsan-japanese", "GPTSAN-japanese"),
        ("granite", "Granite"),
        ("granite_speech", "GraniteSpeech"),
        ("granitemoe", "GraniteMoeMoe"),
        ("granitemoehybrid", "GraniteMoeHybrid"),
        ("granitemoeshared", "GraniteMoeSharedMoe"),
        ("granitevision", "LLaVA-NeXT"),
        ("graphormer", "Graphormer"),
        ("grounding-dino", "Grounding DINO"),
        ("groupvit", "GroupViT"),
        ("helium", "Helium"),
        ("herbert", "HerBERT"),
        ("hgnet_v2", "HGNet-V2"),
        ("hiera", "Hiera"),
        ("hubert", "Hubert"),
        ("hunyuan_v1_dense", "HunYuanDenseV1"),
        ("hunyuan_v1_moe", "HunYuanMoeV1"),
        ("ibert", "I-BERT"),
        ("idefics", "IDEFICS"),
        ("idefics2", "Idefics2"),
        ("idefics3", "Idefics3"),
        ("idefics3_vision", "Idefics3VisionTransformer"),
        ("ijepa", "I-JEPA"),
        ("imagegpt", "ImageGPT"),
        ("informer", "Informer"),
        ("instructblip", "InstructBLIP"),
        ("instructblipvideo", "InstructBlipVideo"),
        ("internvl", "InternVL"),
        ("internvl_vision", "InternVLVision"),
        ("jamba", "Jamba"),
        ("janus", "Janus"),
        ("jetmoe", "JetMoe"),
        ("jukebox", "Jukebox"),
        ("kosmos-2", "KOSMOS-2"),
        ("kosmos-2.5", "KOSMOS-2.5"),
        ("kyutai_speech_to_text", "KyutaiSpeechToText"),
        ("layoutlm", "LayoutLM"),
        ("layoutlmv2", "LayoutLMv2"),
        ("layoutlmv3", "LayoutLMv3"),
        ("layoutxlm", "LayoutXLM"),
        ("led", "LED"),
        ("levit", "LeViT"),
        ("lfm2", "Lfm2"),
        ("lfm2_vl", "Lfm2Vl"),
        ("lightglue", "LightGlue"),
        ("lilt", "LiLT"),
        ("llama", "LLaMA"),
        ("llama2", "Llama2"),
        ("llama3", "Llama3"),
        ("llama4", "Llama4"),
        ("llama4_text", "Llama4ForCausalLM"),
        ("llava", "LLaVa"),
        ("llava_next", "LLaVA-NeXT"),
        ("llava_next_video", "LLaVa-NeXT-Video"),
        ("llava_onevision", "LLaVA-Onevision"),
        ("longcat_flash", "LongCatFlash"),
        ("longformer", "Longformer"),
        ("longt5", "LongT5"),
        ("luke", "LUKE"),
        ("lxmert", "LXMERT"),
        ("m2m_100", "M2M100"),
        ("madlad-400", "MADLAD-400"),
        ("mamba", "Mamba"),
        ("mamba2", "mamba2"),
        ("marian", "Marian"),
        ("markuplm", "MarkupLM"),
        ("mask2former", "Mask2Former"),
        ("maskformer", "MaskFormer"),
        ("maskformer-swin", "MaskFormerSwin"),
        ("matcha", "MatCha"),
        ("mbart", "mBART"),
        ("mbart50", "mBART-50"),
        ("mctct", "M-CTC-T"),
        ("mega", "MEGA"),
        ("megatron-bert", "Megatron-BERT"),
        ("megatron_gpt2", "Megatron-GPT2"),
        ("metaclip_2", "MetaCLIP 2"),
        ("mgp-str", "MGP-STR"),
        ("mimi", "Mimi"),
        ("minimax", "MiniMax"),
        ("ministral", "Ministral"),
        ("mistral", "Mistral"),
        ("mistral3", "Mistral3"),
        ("mixtral", "Mixtral"),
        ("mlcd", "MLCD"),
        ("mllama", "Mllama"),
        ("mluke", "mLUKE"),
        ("mm-grounding-dino", "MM Grounding DINO"),
        ("mms", "MMS"),
        ("mobilebert", "MobileBERT"),
        ("mobilenet_v1", "MobileNetV1"),
        ("mobilenet_v2", "MobileNetV2"),
        ("mobilevit", "MobileViT"),
        ("mobilevitv2", "MobileViTV2"),
        ("modernbert", "ModernBERT"),
        ("modernbert-decoder", "ModernBertDecoder"),
        ("moonshine", "Moonshine"),
        ("moshi", "Moshi"),
        ("mpnet", "MPNet"),
        ("mpt", "MPT"),
        ("mra", "MRA"),
        ("mt5", "MT5"),
        ("musicgen", "MusicGen"),
        ("musicgen_melody", "MusicGen Melody"),
        ("mvp", "MVP"),
        ("myt5", "myt5"),
        ("nat", "NAT"),
        ("nemotron", "Nemotron"),
        ("nezha", "Nezha"),
        ("nllb", "NLLB"),
        ("nllb-moe", "NLLB-MOE"),
        ("nougat", "Nougat"),
        ("nystromformer", "Nyströmformer"),
        ("olmo", "OLMo"),
        ("olmo2", "OLMo2"),
        ("olmo3", "Olmo3"),
        ("olmoe", "OLMoE"),
        ("omdet-turbo", "OmDet-Turbo"),
        ("oneformer", "OneFormer"),
        ("open-llama", "OpenLlama"),
        ("openai-gpt", "OpenAI GPT"),
        ("opt", "OPT"),
        ("ovis2", "Ovis2"),
        ("owlv2", "OWLv2"),
        ("owlvit", "OWL-ViT"),
        ("paligemma", "PaliGemma"),
        ("parakeet", "Parakeet"),
        ("parakeet_ctc", "Parakeet"),
        ("parakeet_encoder", "ParakeetEncoder"),
        ("patchtsmixer", "PatchTSMixer"),
        ("patchtst", "PatchTST"),
        ("pegasus", "Pegasus"),
        ("pegasus_x", "PEGASUS-X"),
        ("perceiver", "Perceiver"),
        ("perception_encoder", "PerceptionEncoder"),
        ("perception_lm", "PerceptionLM"),
        ("persimmon", "Persimmon"),
        ("phi", "Phi"),
        ("phi3", "Phi3"),
        ("phi4_multimodal", "Phi4Multimodal"),
        ("phimoe", "Phimoe"),
        ("phobert", "PhoBERT"),
        ("pix2struct", "Pix2Struct"),
        ("pixtral", "Pixtral"),
        ("plbart", "PLBart"),
        ("poolformer", "PoolFormer"),
        ("pop2piano", "Pop2Piano"),
        ("prompt_depth_anything", "PromptDepthAnything"),
        ("prophetnet", "ProphetNet"),
        ("pvt", "PVT"),
        ("pvt_v2", "PVTv2"),
        ("qdqbert", "QDQBert"),
        ("qwen2", "Qwen2"),
        ("qwen2_5_omni", "Qwen2_5Omni"),
        ("qwen2_5_vl", "Qwen2_5_VL"),
        ("qwen2_5_vl_text", "Qwen2_5_VL"),
        ("qwen2_audio", "Qwen2Audio"),
        ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
        ("qwen2_moe", "Qwen2MoE"),
        ("qwen2_vl", "Qwen2VL"),
        ("qwen2_vl_text", "Qwen2VL"),
        ("qwen3", "Qwen3"),
        ("qwen3_moe", "Qwen3MoE"),
        ("qwen3_next", "Qwen3Next"),
        ("qwen3_omni_moe", "Qwen3OmniMoE"),
        ("qwen3_vl", "Qwen3VL"),
        ("qwen3_vl_moe", "Qwen3VLMoe"),
        ("qwen3_vl_moe_text", "Qwen3VLMoe"),
        ("qwen3_vl_text", "Qwen3VL"),
        ("rag", "RAG"),
        ("realm", "REALM"),
        ("recurrent_gemma", "RecurrentGemma"),
        ("reformer", "Reformer"),
        ("regnet", "RegNet"),
        ("rembert", "RemBERT"),
        ("resnet", "ResNet"),
        ("retribert", "RetriBERT"),
        ("roberta", "RoBERTa"),
        ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
        ("roc_bert", "RoCBert"),
        ("roformer", "RoFormer"),
        ("rt_detr", "RT-DETR"),
        ("rt_detr_resnet", "RT-DETR-ResNet"),
        ("rt_detr_v2", "RT-DETRv2"),
        ("rwkv", "RWKV"),
        ("sam", "SAM"),
        ("sam2", "SAM2"),
        ("sam2_hiera_det_model", "Sam2HieraDetModel"),
        ("sam2_video", "Sam2VideoModel"),
        ("sam2_vision_model", "Sam2VisionModel"),
        ("sam_hq", "SAM-HQ"),
        ("sam_hq_vision_model", "SamHQVisionModel"),
        ("sam_vision_model", "SamVisionModel"),
        ("seamless_m4t", "SeamlessM4T"),
        ("seamless_m4t_v2", "SeamlessM4Tv2"),
        ("seed_oss", "SeedOss"),
        ("segformer", "SegFormer"),
        ("seggpt", "SegGPT"),
        ("sew", "SEW"),
        ("sew-d", "SEW-D"),
        ("shieldgemma2", "Shieldgemma2"),
        ("siglip", "SigLIP"),
        ("siglip2", "SigLIP2"),
        ("siglip2_vision_model", "Siglip2VisionModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("smollm3", "SmolLM3"),
        ("smolvlm", "SmolVLM"),
        ("smolvlm_vision", "SmolVLMVisionTransformer"),
        ("speech-encoder-decoder", "Speech Encoder decoder"),
        ("speech_to_text", "Speech2Text"),
        ("speech_to_text_2", "Speech2Text2"),
        ("speecht5", "SpeechT5"),
        ("splinter", "Splinter"),
        ("squeezebert", "SqueezeBERT"),
        ("stablelm", "StableLm"),
        ("starcoder2", "Starcoder2"),
        ("superglue", "SuperGlue"),
        ("superpoint", "SuperPoint"),
        ("swiftformer", "SwiftFormer"),
        ("swin", "Swin Transformer"),
        ("swin2sr", "Swin2SR"),
        ("swinv2", "Swin Transformer V2"),
        ("switch_transformers", "SwitchTransformers"),
        ("t5", "T5"),
        ("t5gemma", "T5Gemma"),
        ("t5v1.1", "T5v1.1"),
        ("table-transformer", "Table Transformer"),
        ("tapas", "TAPAS"),
        ("tapex", "TAPEX"),
        ("textnet", "TextNet"),
        ("time_series_transformer", "Time Series Transformer"),
        ("timesfm", "TimesFm"),
        ("timesformer", "TimeSformer"),
        ("timm_backbone", "TimmBackbone"),
        ("timm_wrapper", "TimmWrapperModel"),
        ("trajectory_transformer", "Trajectory Transformer"),
        ("transfo-xl", "Transformer-XL"),
        ("trocr", "TrOCR"),
        ("tvlt", "TVLT"),
        ("tvp", "TVP"),
        ("udop", "UDOP"),
        ("ul2", "UL2"),
        ("umt5", "UMT5"),
        ("unispeech", "UniSpeech"),
        ("unispeech-sat", "UniSpeechSat"),
        ("univnet", "UnivNet"),
        ("upernet", "UPerNet"),
        ("van", "VAN"),
        ("vaultgemma", "VaultGemma"),
        ("video_llava", "VideoLlava"),
        ("videomae", "VideoMAE"),
        ("vilt", "ViLT"),
        ("vipllava", "VipLlava"),
        ("vision-encoder-decoder", "Vision Encoder decoder"),
        ("vision-text-dual-encoder", "VisionTextDualEncoder"),
        ("visual_bert", "VisualBERT"),
        ("vit", "ViT"),
        ("vit_hybrid", "ViT Hybrid"),
        ("vit_mae", "ViTMAE"),
        ("vit_msn", "ViTMSN"),
        ("vitdet", "VitDet"),
        ("vitmatte", "ViTMatte"),
        ("vitpose", "ViTPose"),
        ("vitpose_backbone", "ViTPoseBackbone"),
        ("vits", "VITS"),
        ("vivit", "ViViT"),
        ("vjepa2", "VJEPA2Model"),
        ("voxtral", "Voxtral"),
        ("voxtral_encoder", "Voxtral Encoder"),
        ("wav2vec2", "Wav2Vec2"),
        ("wav2vec2-bert", "Wav2Vec2-BERT"),
        ("wav2vec2-conformer", "Wav2Vec2-Conformer"),
        ("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
        ("wavlm", "WavLM"),
        ("whisper", "Whisper"),
        ("xclip", "X-CLIP"),
        ("xcodec", "X-CODEC"),
        ("xglm", "XGLM"),
        ("xlm", "XLM"),
        ("xlm-prophetnet", "XLM-ProphetNet"),
        ("xlm-roberta", "XLM-RoBERTa"),
        ("xlm-roberta-xl", "XLM-RoBERTa-XL"),
        ("xlm-v", "XLM-V"),
        ("xlnet", "XLNet"),
        ("xls_r", "XLS-R"),
        ("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
        ("xlstm", "xLSTM"),
        ("xmod", "X-MOD"),
        ("yolos", "YOLOS"),
        ("yoso", "YOSO"),
        ("zamba", "Zamba"),
        ("zamba2", "Zamba2"),
        ("zoedepth", "ZoeDepth"),
    ]
)

# This is tied to the processing `-` -> `_` in `model_type_to_module_name`. For example, instead of putting
# `transfo-xl` (as in `CONFIG_MAPPING_NAMES`), we should use `transfo_xl`.
DEPRECATED_MODELS = [
    "bort",
    "deta",
    "efficientformer",
    "ernie_m",
    "gptsan_japanese",
    "graphormer",
    "jukebox",
    "mctct",
    "mega",
    "mmbt",
    "nat",
    "nezha",
    "open_llama",
    "qdqbert",
    "realm",
    "retribert",
    "speech_to_text_2",
    "tapex",
    "trajectory_transformer",
    "transfo_xl",
    "tvlt",
    "van",
    "vit_hybrid",
    "xlm_prophetnet",
]

SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
    [
        ("openai-gpt", "openai"),
        ("data2vec-audio", "data2vec"),
        ("data2vec-text", "data2vec"),
        ("data2vec-vision", "data2vec"),
        ("donut-swin", "donut"),
        ("kosmos-2", "kosmos2"),
        ("kosmos-2.5", "kosmos2_5"),
        ("maskformer-swin", "maskformer"),
        ("xclip", "x_clip"),
        ("clip_vision_model", "clip"),
        ("qwen2_audio_encoder", "qwen2_audio"),
        ("voxtral_encoder", "voxtral"),
        ("clip_text_model", "clip"),
        ("aria_text", "aria"),
        ("gemma3_text", "gemma3"),
        ("gemma3n_audio", "gemma3n"),
        ("gemma3n_text", "gemma3n"),
        ("gemma3n_vision", "gemma3n"),
        ("glm4v_text", "glm4v"),
        ("glm4v_moe_text", "glm4v_moe"),
        ("idefics3_vision", "idefics3"),
        ("siglip_vision_model", "siglip"),
        ("siglip2_vision_model", "siglip2"),
        ("aimv2_vision_model", "aimv2"),
        ("smolvlm_vision", "smolvlm"),
        ("chinese_clip_vision_model", "chinese_clip"),
        ("rt_detr_resnet", "rt_detr"),
        ("granitevision", "llava_next"),
        ("internvl_vision", "internvl"),
        ("qwen2_5_vl_text", "qwen2_5_vl"),
        ("qwen2_vl_text", "qwen2_vl"),
        ("qwen3_vl_text", "qwen3_vl"),
        ("qwen3_vl_moe_text", "qwen3_vl_moe"),
        ("sam_vision_model", "sam"),
        ("sam2_vision_model", "sam2"),
        ("edgetam_vision_model", "edgetam"),
        ("sam2_hiera_det_model", "sam2"),
        ("sam_hq_vision_model", "sam_hq"),
        ("llama4_text", "llama4"),
        ("blip_2_qformer", "blip_2"),
        ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
        ("perception_encoder", "perception_lm"),
        ("parakeet_encoder", "parakeet"),
        ("parakeet_ctc", "parakeet"),
    ]
)


def model_type_to_module_name(key) -> str:
    """Converts a config key to the corresponding module."""
    # Special treatment
    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        key = SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
        if key in DEPRECATED_MODELS:
            key = f"deprecated.{key}"
        return key

    key = key.replace("-", "_")
    if key in DEPRECATED_MODELS:
        key = f"deprecated.{key}"

    return key

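
# Illustrative behaviour, derived from the tables above:
#
#     >>> model_type_to_module_name("kosmos-2")  # special-cased module name
#     'kosmos2'
#     >>> model_type_to_module_name("transfo-xl")  # dashes become underscores, then flagged deprecated
#     'deprecated.transfo_xl'
#     >>> model_type_to_module_name("llava_next")  # everything else passes through
#     'llava_next'
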
def config_class_to_model_type(config) -> Union[str, None]:
    """Converts a config class name to the corresponding model type"""
    for key, cls in CONFIG_MAPPING_NAMES.items():
        if cls == config:
            return key
    # if key not found check in extra content
    for key, cls in CONFIG_MAPPING._extra_content.items():
        if cls.__name__ == config:
            return key
    return None

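
# For example, with the values in CONFIG_MAPPING_NAMES above:
#
#     >>> config_class_to_model_type("BertConfig")
#     'bert'
#     >>> config_class_to_model_type("NotAConfig") is None
#     True
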
class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping) -> None:
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key: str) -> type[PretrainedConfig]:
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = model_type_to_module_name(key)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        if hasattr(self._modules[module_name], value):
            return getattr(self._modules[module_name], value)

        # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
        # object at the top level.
        transformers_module = importlib.import_module("transformers")
        return getattr(transformers_module, value)

    def keys(self) -> list[str]:
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self) -> list[type[PretrainedConfig]]:
        return [self[k] for k in self._mapping] + list(self._extra_content.values())

    def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
        return [(k, self[k]) for k in self._mapping] + list(self._extra_content.items())

    def __iter__(self) -> Iterator[str]:
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

    def __contains__(self, item: object) -> bool:
        return item in self._mapping or item in self._extra_content

    def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping and not exist_ok:
            raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
        self._extra_content[key] = value


CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)

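
# Lazy access in practice: `transformers.models.bert` is only imported on the first
# lookup, so building CONFIG_MAPPING itself stays cheap. A sketch, assuming the
# `transformers` package is importable:
#
#     >>> CONFIG_MAPPING["bert"]
#     <class 'transformers.models.bert.configuration_bert.BertConfig'>
#     >>> "bert" in CONFIG_MAPPING
#     True
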
class _LazyLoadAllMappings(OrderedDict[str, str]):
    """
    A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys,
    values, etc.)

    Args:
        mapping: The mapping to load.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._initialized = False
        self._data = {}

    def _initialize(self):
        if self._initialized:
            return

        for model_type, map_name in self._mapping.items():
            module_name = model_type_to_module_name(model_type)
            module = importlib.import_module(f".{module_name}", "transformers.models")
            mapping = getattr(module, map_name)
            self._data.update(mapping)

        self._initialized = True

    def __getitem__(self, key):
        self._initialize()
        return self._data[key]

    def keys(self) -> KeysView[str]:
        self._initialize()
        return self._data.keys()

    def values(self) -> ValuesView[str]:
        self._initialize()
        return self._data.values()

    def items(self) -> ItemsView[str, str]:
        self._initialize()
        return self._data.items()

    def __iter__(self) -> Iterator[str]:
        self._initialize()
        return iter(self._data)

    def __contains__(self, item: object) -> bool:
        self._initialize()
        return item in self._data

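
# The lazy pattern in brief: constructing the mapping imports nothing; the first read
# (indexing, `in`, `keys()`, ...) imports each model module once and merges the named
# dicts. A sketch, assuming a module-level dict name (`SOME_ARCHIVE_MAP` is hypothetical):
#
#     >>> lazy = _LazyLoadAllMappings({"bert": "SOME_ARCHIVE_MAP"})  # no import happens here
#     >>> "bert-base-uncased" in lazy  # first access imports transformers.models.bert and merges
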
def _get_class_name(model_class: Union[str, list[str]]):
    if isinstance(model_class, (list, tuple)):
        return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
    return f"[`{model_class}`]"

def _list_model_options(indent, config_to_class=None, use_model_types=True):
    if config_to_class is None and not use_model_types:
        raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
    if use_model_types:
        if config_to_class is None:
            model_type_to_name = {model_type: f"[`{config}`]" for model_type, config in CONFIG_MAPPING_NAMES.items()}
        else:
            model_type_to_name = {
                model_type: _get_class_name(model_class)
                for model_type, model_class in config_to_class.items()
                if model_type in MODEL_NAMES_MAPPING
            }
        lines = [
            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
            for model_type in sorted(model_type_to_name.keys())
        ]
    else:
        config_to_name = {
            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
            for config, clas in config_to_class.items()
            if config in CONFIG_MAPPING_NAMES
        }
        config_to_model_name = {
            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
        }
        lines = [
            f"{indent}- [`{config_name}`] configuration class:"
            f" {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
            for config_name in sorted(config_to_name.keys())
        ]
    return "\n".join(lines)

def replace_list_option_in_docstrings(
    config_to_class=None, use_model_types: bool = True
) -> Callable[[_CallableT], _CallableT]:
    def docstring_decorator(fn):
        docstrings = fn.__doc__
        if docstrings is None:
            # Can happen when running with `python -OO`, which strips docstrings.
            return fn
        lines = docstrings.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^(\s*)List options\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            indent = re.search(r"^(\s*)List options\s*$", lines[i]).groups()[0]
            if use_model_types:
                indent = f"{indent}    "
            lines[i] = _list_model_options(indent, config_to_class=config_to_class, use_model_types=use_model_types)
            docstrings = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current"
                f" docstring is:\n{docstrings}"
            )
        fn.__doc__ = docstrings
        return fn

    return docstring_decorator

class AutoConfig:
|
| 1208 |
+
r"""
|
| 1209 |
+
This is a generic configuration class that will be instantiated as one of the configuration classes of the library
|
| 1210 |
+
when created with the [`~AutoConfig.from_pretrained`] class method.
|
| 1211 |
+
|
| 1212 |
+
This class cannot be instantiated directly using `__init__()` (throws an error).
|
| 1213 |
+
"""
|
| 1214 |
+
|
| 1215 |
+
def __init__(self) -> None:
|
| 1216 |
+
raise OSError(
|
| 1217 |
+
"AutoConfig is designed to be instantiated "
|
| 1218 |
+
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
| 1219 |
+
)
|
| 1220 |
+
|
| 1221 |
+
@classmethod
|
| 1222 |
+
def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
|
| 1223 |
+
if model_type in CONFIG_MAPPING:
|
| 1224 |
+
config_class = CONFIG_MAPPING[model_type]
|
| 1225 |
+
return config_class(*args, **kwargs)
|
| 1226 |
+
raise ValueError(
|
| 1227 |
+
f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
|
| 1228 |
+
)
|
| 1229 |
+
|
| 1230 |
+
@classmethod
|
| 1231 |
+
@replace_list_option_in_docstrings()
|
| 1232 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
|
| 1233 |
+
r"""
|
| 1234 |
+
Instantiate one of the configuration classes of the library from a pretrained model configuration.
|
| 1235 |
+
|
| 1236 |
+
The configuration class to instantiate is selected based on the `model_type` property of the config object that
|
| 1237 |
+
is loaded, or when it's missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
|
| 1238 |
+
|
| 1239 |
+
List options
|
| 1240 |
+
|
| 1241 |
+
Args:
|
| 1242 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 1243 |
+
Can be either:
|
| 1244 |
+
|
| 1245 |
+
- A string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
| 1246 |
+
huggingface.co.
|
| 1247 |
+
- A path to a *directory* containing a configuration file saved using the
|
| 1248 |
+
[`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
|
| 1249 |
+
e.g., `./my_model_directory/`.
|
| 1250 |
+
- A path or url to a saved configuration JSON *file*, e.g.,
|
| 1251 |
+
`./my_model_directory/configuration.json`.
|
| 1252 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 1253 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 1254 |
+
standard cache should not be used.
|
| 1255 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 1256 |
+
Whether or not to force the (re-)download the model weights and configuration files and override the
|
| 1257 |
+
cached versions if they exist.
|
| 1258 |
+
resume_download:
|
| 1259 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 1260 |
+
Will be removed in v5 of Transformers.
|
| 1261 |
+
proxies (`dict[str, str]`, *optional*):
|
| 1262 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 1263 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 1264 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 1265 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 1266 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 1267 |
+
identifier allowed by git.
|
| 1268 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 1269 |
+
If `False`, then this function returns just the final configuration object.
|
| 1270 |
+
|
| 1271 |
+
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
| 1272 |
+
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
| 1273 |
+
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
| 1274 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 1275 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 1276 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 1277 |
+
execute code present on the Hub on your local machine.
|
| 1278 |
+
kwargs(additional keyword arguments, *optional*):
|
| 1279 |
+
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
| 1280 |
+
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
| 1281 |
+
by the `return_unused_kwargs` keyword parameter.
|
| 1282 |
+
|
| 1283 |
+
Examples:
|
| 1284 |
+
|
| 1285 |
+
```python
|
| 1286 |
+
>>> from transformers import AutoConfig
|
| 1287 |
+
|
| 1288 |
+
>>> # Download configuration from huggingface.co and cache.
|
| 1289 |
+
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
|
| 1290 |
+
|
| 1291 |
+
>>> # Download configuration from huggingface.co (user-uploaded) and cache.
|
| 1292 |
+
>>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
|
| 1293 |
+
|
| 1294 |
+
>>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
|
| 1295 |
+
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
|
| 1296 |
+
|
| 1297 |
+
>>> # Load a specific configuration file.
|
| 1298 |
+
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
|
| 1299 |
+
|
| 1300 |
+
>>> # Change some config attributes when loading a pretrained config.
|
| 1301 |
+
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
|
| 1302 |
+
>>> config.output_attentions
|
| 1303 |
+
True
|
| 1304 |
+
|
| 1305 |
+
>>> config, unused_kwargs = AutoConfig.from_pretrained(
|
| 1306 |
+
... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
|
| 1307 |
+
... )
|
| 1308 |
+
>>> config.output_attentions
|
| 1309 |
+
True
|
| 1310 |
+
|
| 1311 |
+
>>> unused_kwargs
|
| 1312 |
+
{'foo': False}
|
| 1313 |
+
```
|
| 1314 |
+
"""
|
| 1315 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 1316 |
+
if use_auth_token is not None:
|
| 1317 |
+
warnings.warn(
|
| 1318 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 1319 |
+
FutureWarning,
|
| 1320 |
+
)
|
| 1321 |
+
if kwargs.get("token") is not None:
|
| 1322 |
+
raise ValueError(
|
| 1323 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 1324 |
+
)
|
| 1325 |
+
kwargs["token"] = use_auth_token
|
| 1326 |
+
|
| 1327 |
+
kwargs["_from_auto"] = True
|
| 1328 |
+
kwargs["name_or_path"] = pretrained_model_name_or_path
|
| 1329 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 1330 |
+
code_revision = kwargs.pop("code_revision", None)
|
| 1331 |
+
|
| 1332 |
+
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 1333 |
+
has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
|
| 1334 |
+
has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
|
| 1335 |
+
if has_remote_code:
|
| 1336 |
+
class_ref = config_dict["auto_map"]["AutoConfig"]
|
| 1337 |
+
if "--" in class_ref:
|
| 1338 |
+
upstream_repo = class_ref.split("--")[0]
|
| 1339 |
+
else:
|
| 1340 |
+
upstream_repo = None
|
| 1341 |
+
trust_remote_code = resolve_trust_remote_code(
|
| 1342 |
+
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
|
| 1343 |
+
)
|
| 1344 |
+
|
| 1345 |
+
if has_remote_code and trust_remote_code:
|
| 1346 |
+
config_class = get_class_from_dynamic_module(
|
| 1347 |
+
class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
|
| 1348 |
+
)
|
| 1349 |
+
config_class.register_for_auto_class()
|
| 1350 |
+
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 1351 |
+
elif "model_type" in config_dict:
|
| 1352 |
+
# Apply heuristic: if model_type is mistral but layer_types is present, treat as ministral
|
| 1353 |
+
if config_dict["model_type"] == "mistral" and "layer_types" in config_dict:
|
| 1354 |
+
logger.info(
|
| 1355 |
+
"Detected mistral model with layer_types, treating as ministral for alternating attention compatibility. "
|
| 1356 |
+
)
|
| 1357 |
+
config_dict["model_type"] = "ministral"
|
| 1358 |
+
|
| 1359 |
+
try:
|
| 1360 |
+
config_class = CONFIG_MAPPING[config_dict["model_type"]]
|
| 1361 |
+
except KeyError:
|
| 1362 |
+
raise ValueError(
|
| 1363 |
+
f"The checkpoint you are trying to load has model type `{config_dict['model_type']}` "
|
| 1364 |
+
"but Transformers does not recognize this architecture. This could be because of an "
|
| 1365 |
+
"issue with the checkpoint, or because your version of Transformers is out of date.\n\n"
|
| 1366 |
+
"You can update Transformers with the command `pip install --upgrade transformers`. If this "
|
| 1367 |
+
"does not work, and the checkpoint is very new, then there may not be a release version "
|
| 1368 |
+
"that supports this model yet. In this case, you can get the most up-to-date code by installing "
|
| 1369 |
+
"Transformers from source with the command "
|
| 1370 |
+
"`pip install git+https://github.com/huggingface/transformers.git`"
|
| 1371 |
+
)
|
| 1372 |
+
return config_class.from_dict(config_dict, **unused_kwargs)
|
| 1373 |
+
else:
|
| 1374 |
+
# Fallback: use pattern matching on the string.
|
| 1375 |
+
# We go from longer names to shorter names to catch roberta before bert (for instance)
|
| 1376 |
+
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
|
| 1377 |
+
if pattern in str(pretrained_model_name_or_path):
|
| 1378 |
+
return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
|
| 1379 |
+
|
| 1380 |
+
raise ValueError(
|
| 1381 |
+
f"Unrecognized model in {pretrained_model_name_or_path}. "
|
| 1382 |
+
f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
|
| 1383 |
+
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
|
| 1384 |
+
)
|
| 1385 |
+
|
| 1386 |
+
@staticmethod
|
| 1387 |
+
def register(model_type, config, exist_ok=False) -> None:
|
| 1388 |
+
"""
|
| 1389 |
+
Register a new configuration for this class.
|
| 1390 |
+
|
| 1391 |
+
Args:
|
| 1392 |
+
model_type (`str`): The model type like "bert" or "gpt".
|
| 1393 |
+
config ([`PretrainedConfig`]): The config to register.
|
| 1394 |
+
"""
|
| 1395 |
+
if issubclass(config, PretrainedConfig) and config.model_type != model_type:
|
| 1396 |
+
raise ValueError(
|
| 1397 |
+
"The config you are passing has a `model_type` attribute that is not consistent with the model type "
|
| 1398 |
+
f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
|
| 1399 |
+
"match!"
|
| 1400 |
+
)
|
| 1401 |
+
CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
|
| 1402 |
+
|
| 1403 |
+
|
| 1404 |
+
__all__ = ["CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"]
|
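The `register` hook above is the public entry point for extending the lazy `CONFIG_MAPPING`. A minimal sketch of how a downstream package would use it; the `my-model` type and `MyModelConfig` class are hypothetical names for illustration, not part of the file above:

```python
from transformers import AutoConfig, PretrainedConfig


class MyModelConfig(PretrainedConfig):
    # Must match the string passed to AutoConfig.register, or the
    # consistency check inside `register` raises a ValueError.
    model_type = "my-model"

    def __init__(self, hidden_size: int = 64, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)


# Hypothetical model type: after registration, AutoConfig.for_model("my-model")
# instantiates MyModelConfig, and from_pretrained resolves it via CONFIG_MAPPING.
AutoConfig.register("my-model", MyModelConfig)
assert isinstance(AutoConfig.for_model("my-model"), MyModelConfig)
```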
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/feature_extraction_auto.py
ADDED
@@ -0,0 +1,422 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoFeatureExtractor class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import Optional, Union

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)

FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
    [
        ("audio-spectrogram-transformer", "ASTFeatureExtractor"),
        ("beit", "BeitFeatureExtractor"),
        ("chinese_clip", "ChineseCLIPFeatureExtractor"),
        ("clap", "ClapFeatureExtractor"),
        ("clip", "CLIPFeatureExtractor"),
        ("clipseg", "ViTFeatureExtractor"),
        ("clvp", "ClvpFeatureExtractor"),
        ("conditional_detr", "ConditionalDetrFeatureExtractor"),
        ("convnext", "ConvNextFeatureExtractor"),
        ("cvt", "ConvNextFeatureExtractor"),
        ("dac", "DacFeatureExtractor"),
        ("data2vec-audio", "Wav2Vec2FeatureExtractor"),
        ("data2vec-vision", "BeitFeatureExtractor"),
        ("deformable_detr", "DeformableDetrFeatureExtractor"),
        ("deit", "DeiTFeatureExtractor"),
        ("detr", "DetrFeatureExtractor"),
        ("dia", "DiaFeatureExtractor"),
        ("dinat", "ViTFeatureExtractor"),
        ("donut-swin", "DonutFeatureExtractor"),
        ("dpt", "DPTFeatureExtractor"),
        ("encodec", "EncodecFeatureExtractor"),
        ("flava", "FlavaFeatureExtractor"),
        ("gemma3n", "Gemma3nAudioFeatureExtractor"),
        ("glpn", "GLPNFeatureExtractor"),
        ("granite_speech", "GraniteSpeechFeatureExtractor"),
        ("groupvit", "CLIPFeatureExtractor"),
        ("hubert", "Wav2Vec2FeatureExtractor"),
        ("imagegpt", "ImageGPTFeatureExtractor"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
        ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
        ("layoutlmv3", "LayoutLMv3FeatureExtractor"),
        ("levit", "LevitFeatureExtractor"),
        ("maskformer", "MaskFormerFeatureExtractor"),
        ("mctct", "MCTCTFeatureExtractor"),
        ("mimi", "EncodecFeatureExtractor"),
        ("mobilenet_v1", "MobileNetV1FeatureExtractor"),
        ("mobilenet_v2", "MobileNetV2FeatureExtractor"),
        ("mobilevit", "MobileViTFeatureExtractor"),
        ("moonshine", "Wav2Vec2FeatureExtractor"),
        ("moshi", "EncodecFeatureExtractor"),
        ("nat", "ViTFeatureExtractor"),
        ("owlvit", "OwlViTFeatureExtractor"),
        ("parakeet_ctc", "ParakeetFeatureExtractor"),
        ("parakeet_encoder", "ParakeetFeatureExtractor"),
        ("perceiver", "PerceiverFeatureExtractor"),
        ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
        ("poolformer", "PoolFormerFeatureExtractor"),
        ("pop2piano", "Pop2PianoFeatureExtractor"),
        ("regnet", "ConvNextFeatureExtractor"),
        ("resnet", "ConvNextFeatureExtractor"),
        ("seamless_m4t", "SeamlessM4TFeatureExtractor"),
        ("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
        ("segformer", "SegformerFeatureExtractor"),
        ("sew", "Wav2Vec2FeatureExtractor"),
        ("sew-d", "Wav2Vec2FeatureExtractor"),
        ("speech_to_text", "Speech2TextFeatureExtractor"),
        ("speecht5", "SpeechT5FeatureExtractor"),
        ("swiftformer", "ViTFeatureExtractor"),
        ("swin", "ViTFeatureExtractor"),
        ("swinv2", "ViTFeatureExtractor"),
        ("table-transformer", "DetrFeatureExtractor"),
        ("timesformer", "VideoMAEFeatureExtractor"),
        ("tvlt", "TvltFeatureExtractor"),
        ("unispeech", "Wav2Vec2FeatureExtractor"),
        ("unispeech-sat", "Wav2Vec2FeatureExtractor"),
        ("univnet", "UnivNetFeatureExtractor"),
        ("van", "ConvNextFeatureExtractor"),
        ("videomae", "VideoMAEFeatureExtractor"),
        ("vilt", "ViltFeatureExtractor"),
        ("vit", "ViTFeatureExtractor"),
        ("vit_mae", "ViTFeatureExtractor"),
        ("vit_msn", "ViTFeatureExtractor"),
        ("wav2vec2", "Wav2Vec2FeatureExtractor"),
        ("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
        ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
        ("wavlm", "Wav2Vec2FeatureExtractor"),
        ("whisper", "WhisperFeatureExtractor"),
        ("xclip", "CLIPFeatureExtractor"),
        ("xcodec", "DacFeatureExtractor"),
        ("yolos", "YolosFeatureExtractor"),
    ]
)

FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)


def feature_extractor_class_from_name(class_name: str):
    for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
        if class_name in extractors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
        if getattr(extractor, "__name__", None) == class_name:
            return extractor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_feature_extractor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the feature extractor configuration from a pretrained model feature extractor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the configuration files, overriding the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the feature extractor configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the feature extractor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    feature_extractor_config = get_feature_extractor_config("facebook/wav2vec2-base-960h")
    # This model does not have a feature extractor config so the result will be an empty dict.
    feature_extractor_config = get_feature_extractor_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained feature extractor locally and you can reload its config
    from transformers import AutoFeatureExtractor

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
    feature_extractor.save_pretrained("feature-extractor-test")
    feature_extractor_config = get_feature_extractor_config("feature-extractor-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        FEATURE_EXTRACTOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the feature extractor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)


class AutoFeatureExtractor:
    r"""
    This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
    library when created with the [`AutoFeatureExtractor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoFeatureExtractor is designed to be instantiated "
            "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.

        The feature extractor class to instantiate is selected based on the `model_type` property of the config object
        (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a feature extractor file saved using the
                  [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
                  `./my_model_directory/`.
                - a path or url to a saved feature extractor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the feature extractor files, overriding the cached
                versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `hf auth login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final feature extractor object. If `True`, then this
                function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
                consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
                `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (`dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are feature extractor attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
                controlled by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `token=True` is required when you want to use a private model.

        </Tip>

        Examples:

        ```python
        >>> from transformers import AutoFeatureExtractor

        >>> # Download feature extractor from huggingface.co and cache.
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

        >>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
        >>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
        feature_extractor_class = config_dict.get("feature_extractor_type", None)
        feature_extractor_auto_map = None
        if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
            feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]

        # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
        if feature_extractor_class is None and feature_extractor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            # It could be in `config.feature_extractor_type`
            feature_extractor_class = getattr(config, "feature_extractor_type", None)
            if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
                feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]

        if feature_extractor_class is not None:
            feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)

        has_remote_code = feature_extractor_auto_map is not None
        has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING
        if has_remote_code:
            if "--" in feature_extractor_auto_map:
                upstream_repo = feature_extractor_auto_map.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            feature_extractor_class = get_class_from_dynamic_module(
                feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs
            )
            _ = kwargs.pop("code_revision", None)
            feature_extractor_class.register_for_auto_class()
            return feature_extractor_class.from_dict(config_dict, **kwargs)
        elif feature_extractor_class is not None:
            return feature_extractor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
        elif type(config) in FEATURE_EXTRACTOR_MAPPING:
            feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
            return feature_extractor_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
            f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} or {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES)}"
        )

    @staticmethod
    def register(config_class, feature_extractor_class, exist_ok=False):
        """
        Register a new feature extractor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
        """
        FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)


__all__ = ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"]
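The resolution chain implemented above (explicit `feature_extractor_type`, then a trusted remote `auto_map` entry, then `FEATURE_EXTRACTOR_MAPPING` keyed by config class) is all driven through `from_pretrained`. A minimal sketch, assuming network access to the public `facebook/wav2vec2-base-960h` checkpoint already used in the docstring above:

```python
from transformers import AutoFeatureExtractor

# "wav2vec2" maps to Wav2Vec2FeatureExtractor in FEATURE_EXTRACTOR_MAPPING_NAMES,
# so the auto class resolves that extractor and loads its preprocessor_config.json.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
print(type(feature_extractor).__name__)  # Wav2Vec2FeatureExtractor
```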
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/image_processing_auto.py
ADDED
@@ -0,0 +1,688 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoImageProcessor class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Union

# Build the list of all image processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...image_processing_utils import ImageProcessingMixin
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...utils import (
    CONFIG_NAME,
    IMAGE_PROCESSOR_NAME,
    cached_file,
    is_timm_config_dict,
    is_timm_local_checkpoint,
    is_torchvision_available,
    is_vision_available,
    logging,
)
from ...utils.import_utils import requires
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)


FORCE_FAST_IMAGE_PROCESSOR = ["Qwen2VLImageProcessor"]


if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    IMAGE_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
        [
            ("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
            ("aria", ("AriaImageProcessor", None)),
            ("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
            ("bit", ("BitImageProcessor", "BitImageProcessorFast")),
            ("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
            ("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
            ("bridgetower", ("BridgeTowerImageProcessor", "BridgeTowerImageProcessorFast")),
            ("chameleon", ("ChameleonImageProcessor", "ChameleonImageProcessorFast")),
            ("chinese_clip", ("ChineseCLIPImageProcessor", "ChineseCLIPImageProcessorFast")),
            ("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
            ("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
            ("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
            ("deepseek_vl", ("DeepseekVLImageProcessor", "DeepseekVLImageProcessorFast")),
            ("deepseek_vl_hybrid", ("DeepseekVLHybridImageProcessor", "DeepseekVLHybridImageProcessorFast")),
            ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
            ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
            ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")),
            ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")),
            ("deta", ("DetaImageProcessor", None)),
            ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
            ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
            ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
            ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
            ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
            ("edgetam", (None, "Sam2ImageProcessorFast")),
            ("efficientformer", ("EfficientFormerImageProcessor", None)),
            ("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
            ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
            ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
            ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
            ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
            ("fuyu", ("FuyuImageProcessor", None)),
            ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
            ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
            ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
            ("glpn", ("GLPNImageProcessor", None)),
            ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
            ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
            ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
            ("idefics", ("IdeficsImageProcessor", None)),
            ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
            ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
            ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
            ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
            ("instructblipvideo", ("InstructBlipVideoImageProcessor", None)),
            ("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
            ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
            ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
            ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
            ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
            ("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
            ("lightglue", ("LightGlueImageProcessor", None)),
            ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
            ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
            ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
            ("llava_next_video", ("LlavaNextVideoImageProcessor", None)),
            ("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")),
            ("mask2former", ("Mask2FormerImageProcessor", "Mask2FormerImageProcessorFast")),
            ("maskformer", ("MaskFormerImageProcessor", "MaskFormerImageProcessorFast")),
            ("metaclip_2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
            ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("mllama", ("MllamaImageProcessor", None)),
            ("mm-grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
            ("mobilenet_v1", ("MobileNetV1ImageProcessor", "MobileNetV1ImageProcessorFast")),
            ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
            ("mobilevit", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
            ("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
            ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
            ("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
            ("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
            ("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
            ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
            ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
            ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
            ("perception_lm", (None, "PerceptionLMImageProcessorFast")),
            ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")),
            ("pix2struct", ("Pix2StructImageProcessor", None)),
            ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
            ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
            ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
            ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
            ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
            ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
            ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
            ("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
            ("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
            ("sam", ("SamImageProcessor", "SamImageProcessorFast")),
            ("sam2", (None, "Sam2ImageProcessorFast")),
            ("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
            ("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
            ("seggpt", ("SegGptImageProcessor", None)),
            ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
            ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
            ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
            ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
            ("superglue", ("SuperGlueImageProcessor", None)),
            ("superpoint", ("SuperPointImageProcessor", "SuperPointImageProcessorFast")),
            ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
            ("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("table-transformer", ("DetrImageProcessor", "DetrImageProcessorFast")),
            ("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
            ("timesformer", ("VideoMAEImageProcessor", None)),
            ("timm_wrapper", ("TimmWrapperImageProcessor", None)),
            ("tvlt", ("TvltImageProcessor", None)),
            ("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
            ("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
            ("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
            ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
            ("videomae", ("VideoMAEImageProcessor", None)),
            ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
            ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("vit_hybrid", ("ViTHybridImageProcessor", None)),
            ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
            ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
            ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
            ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
            ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
        ]
    )

# Override to None if the packages are not available
for model_type, (slow_class, fast_class) in IMAGE_PROCESSOR_MAPPING_NAMES.items():
    if not is_vision_available():
        slow_class = None
    if not is_torchvision_available():
        fast_class = None

    IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_class, fast_class)

IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)


def get_image_processor_class_from_name(class_name: str):
    if class_name == "BaseImageProcessorFast":
        return BaseImageProcessorFast

    for module_name, extractors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
        if class_name in extractors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values():
        for extractor in extractors:
            if getattr(extractor, "__name__", None) == class_name:
                return extractor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_image_processor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the image processor configuration from a pretrained model image processor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the configuration files, overriding the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the image processor configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the image processor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    image_processor_config = get_image_processor_config("google-bert/bert-base-uncased")
    # This model does not have an image processor config so the result will be an empty dict.
    image_processor_config = get_image_processor_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained image processor locally and you can reload its config
    from transformers import AutoImageProcessor

    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    image_processor.save_pretrained("image-processor-test")
    image_processor_config = get_image_processor_config("image-processor-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
|
| 313 |
+
if use_auth_token is not None:
|
| 314 |
+
warnings.warn(
|
| 315 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 316 |
+
FutureWarning,
|
| 317 |
+
)
|
| 318 |
+
if token is not None:
|
| 319 |
+
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
| 320 |
+
token = use_auth_token
|
| 321 |
+
|
| 322 |
+
resolved_config_file = cached_file(
|
| 323 |
+
pretrained_model_name_or_path,
|
| 324 |
+
IMAGE_PROCESSOR_NAME,
|
| 325 |
+
cache_dir=cache_dir,
|
| 326 |
+
force_download=force_download,
|
| 327 |
+
resume_download=resume_download,
|
| 328 |
+
proxies=proxies,
|
| 329 |
+
token=token,
|
| 330 |
+
revision=revision,
|
| 331 |
+
local_files_only=local_files_only,
|
| 332 |
+
_raise_exceptions_for_gated_repo=False,
|
| 333 |
+
_raise_exceptions_for_missing_entries=False,
|
| 334 |
+
_raise_exceptions_for_connection_errors=False,
|
| 335 |
+
)
|
| 336 |
+
if resolved_config_file is None:
|
| 337 |
+
logger.info(
|
| 338 |
+
"Could not locate the image processor configuration file, will try to use the model config instead."
|
| 339 |
+
)
|
| 340 |
+
return {}
|
| 341 |
+
|
| 342 |
+
with open(resolved_config_file, encoding="utf-8") as reader:
|
| 343 |
+
return json.load(reader)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _warning_fast_image_processor_available(fast_class):
|
| 347 |
+
logger.warning(
|
| 348 |
+
f"Fast image processor class {fast_class} is available for this model. "
|
| 349 |
+
"Using slow image processor class. To use the fast image processor class set `use_fast=True`."
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
@requires(backends=("vision",))
|
| 354 |
+
class AutoImageProcessor:
|
| 355 |
+
r"""
|
| 356 |
+
This is a generic image processor class that will be instantiated as one of the image processor classes of the
|
| 357 |
+
library when created with the [`AutoImageProcessor.from_pretrained`] class method.
|
| 358 |
+
|
| 359 |
+
This class cannot be instantiated directly using `__init__()` (throws an error).
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
def __init__(self):
|
| 363 |
+
raise OSError(
|
| 364 |
+
"AutoImageProcessor is designed to be instantiated "
|
| 365 |
+
"using the `AutoImageProcessor.from_pretrained(pretrained_model_name_or_path)` method."
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
@classmethod
|
| 369 |
+
@replace_list_option_in_docstrings(IMAGE_PROCESSOR_MAPPING_NAMES)
|
| 370 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
| 371 |
+
r"""
|
| 372 |
+
Instantiate one of the image processor classes of the library from a pretrained model vocabulary.
|
| 373 |
+
|
| 374 |
+
The image processor class to instantiate is selected based on the `model_type` property of the config object
|
| 375 |
+
(either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
|
| 376 |
+
missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
|
| 377 |
+
|
| 378 |
+
List options
|
| 379 |
+
|
| 380 |
+
Params:
|
| 381 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 382 |
+
This can be either:
|
| 383 |
+
|
| 384 |
+
- a string, the *model id* of a pretrained image_processor hosted inside a model repo on
|
| 385 |
+
huggingface.co.
|
| 386 |
+
- a path to a *directory* containing a image processor file saved using the
|
| 387 |
+
[`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
|
| 388 |
+
`./my_model_directory/`.
|
| 389 |
+
- a path or url to a saved image processor JSON *file*, e.g.,
|
| 390 |
+
`./my_model_directory/preprocessor_config.json`.
|
| 391 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 392 |
+
Path to a directory in which a downloaded pretrained model image processor should be cached if the
|
| 393 |
+
standard cache should not be used.
|
| 394 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 395 |
+
Whether or not to force to (re-)download the image processor files and override the cached versions if
|
| 396 |
+
they exist.
|
| 397 |
+
resume_download:
|
| 398 |
+
Deprecated and ignored. All downloads are now resumed by default when possible.
|
| 399 |
+
Will be removed in v5 of Transformers.
|
| 400 |
+
proxies (`dict[str, str]`, *optional*):
|
| 401 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 402 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
| 403 |
+
token (`str` or *bool*, *optional*):
|
| 404 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
| 405 |
+
when running `hf auth login` (stored in `~/.huggingface`).
|
| 406 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 407 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 408 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 409 |
+
identifier allowed by git.
|
| 410 |
+
use_fast (`bool`, *optional*, defaults to `False`):
|
| 411 |
+
Use a fast torchvision-base image processor if it is supported for a given model.
|
| 412 |
+
If a fast image processor is not available for a given model, a normal numpy-based image processor
|
| 413 |
+
is returned instead.
|
| 414 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 415 |
+
If `False`, then this function returns just the final image processor object. If `True`, then this
|
| 416 |
+
functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
| 417 |
+
consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
|
| 418 |
+
`kwargs` which has not been used to update `image_processor` and is otherwise ignored.
|
| 419 |
+
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
| 420 |
+
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
| 421 |
+
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
| 422 |
+
execute code present on the Hub on your local machine.
|
| 423 |
+
image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
|
| 424 |
+
The name of the file in the model directory to use for the image processor config.
|
| 425 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 426 |
+
The values in kwargs of any keys which are image processor attributes will be used to override the
|
| 427 |
+
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
|
| 428 |
+
controlled by the `return_unused_kwargs` keyword parameter.
|
| 429 |
+
|
| 430 |
+
<Tip>
|
| 431 |
+
|
| 432 |
+
Passing `token=True` is required when you want to use a private model.
|
| 433 |
+
|
| 434 |
+
</Tip>
|
| 435 |
+
|
| 436 |
+
Examples:
|
| 437 |
+
|
| 438 |
+
```python
|
| 439 |
+
>>> from transformers import AutoImageProcessor
|
| 440 |
+
|
| 441 |
+
>>> # Download image processor from huggingface.co and cache.
|
| 442 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 443 |
+
|
| 444 |
+
>>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
|
| 445 |
+
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
|
| 446 |
+
```"""
|
| 447 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 448 |
+
if use_auth_token is not None:
|
| 449 |
+
warnings.warn(
|
| 450 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
| 451 |
+
FutureWarning,
|
| 452 |
+
)
|
| 453 |
+
if kwargs.get("token") is not None:
|
| 454 |
+
raise ValueError(
|
| 455 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
| 456 |
+
)
|
| 457 |
+
kwargs["token"] = use_auth_token
|
| 458 |
+
|
| 459 |
+
config = kwargs.pop("config", None)
|
| 460 |
+
# TODO: @yoni, change in v4.48 (use_fast set to True by default)
|
| 461 |
+
use_fast = kwargs.pop("use_fast", None)
|
| 462 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 463 |
+
kwargs["_from_auto"] = True
|
| 464 |
+
|
| 465 |
+
# Resolve the image processor config filename
|
| 466 |
+
if "image_processor_filename" in kwargs:
|
| 467 |
+
image_processor_filename = kwargs.pop("image_processor_filename")
|
| 468 |
+
elif is_timm_local_checkpoint(pretrained_model_name_or_path):
|
| 469 |
+
image_processor_filename = CONFIG_NAME
|
| 470 |
+
else:
|
| 471 |
+
image_processor_filename = IMAGE_PROCESSOR_NAME
|
| 472 |
+
|
| 473 |
+
# Load the image processor config
|
| 474 |
+
try:
|
| 475 |
+
# Main path for all transformers models and local TimmWrapper checkpoints
|
| 476 |
+
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
|
| 477 |
+
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
|
| 478 |
+
)
|
| 479 |
+
except Exception as initial_exception:
|
| 480 |
+
# Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
|
| 481 |
+
# instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
|
| 482 |
+
# except the model name, the only way to check if a remote checkpoint is a timm model is to try to
|
| 483 |
+
# load `config.json` and if it fails with some error, we raise the initial exception.
|
| 484 |
+
try:
|
| 485 |
+
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
|
| 486 |
+
pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
|
| 487 |
+
)
|
| 488 |
+
except Exception:
|
| 489 |
+
raise initial_exception
|
| 490 |
+
|
| 491 |
+
# In case we have a config_dict, but it's not a timm config dict, we raise the initial exception,
|
| 492 |
+
# because only timm models have image processing in `config.json`.
|
| 493 |
+
if not is_timm_config_dict(config_dict):
|
| 494 |
+
raise initial_exception
|
| 495 |
+
|
| 496 |
+
image_processor_type = config_dict.get("image_processor_type", None)
|
| 497 |
+
image_processor_auto_map = None
|
| 498 |
+
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
|
| 499 |
+
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
|
| 500 |
+
|
| 501 |
+
# If we still don't have the image processor class, check if we're loading from a previous feature extractor config
|
| 502 |
+
# and if so, infer the image processor class from there.
|
| 503 |
+
if image_processor_type is None and image_processor_auto_map is None:
|
| 504 |
+
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
|
| 505 |
+
if feature_extractor_class is not None:
|
| 506 |
+
image_processor_type = feature_extractor_class.replace("FeatureExtractor", "ImageProcessor")
|
| 507 |
+
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
|
| 508 |
+
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
|
| 509 |
+
image_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "ImageProcessor")
|
| 510 |
+
|
| 511 |
+
# If we don't find the image processor class in the image processor config, let's try the model config.
|
| 512 |
+
if image_processor_type is None and image_processor_auto_map is None:
|
| 513 |
+
if not isinstance(config, PretrainedConfig):
|
| 514 |
+
config = AutoConfig.from_pretrained(
|
| 515 |
+
pretrained_model_name_or_path,
|
| 516 |
+
trust_remote_code=trust_remote_code,
|
| 517 |
+
**kwargs,
|
| 518 |
+
)
|
| 519 |
+
# It could be in `config.image_processor_type``
|
| 520 |
+
image_processor_type = getattr(config, "image_processor_type", None)
|
| 521 |
+
if hasattr(config, "auto_map") and "AutoImageProcessor" in config.auto_map:
|
| 522 |
+
image_processor_auto_map = config.auto_map["AutoImageProcessor"]
|
| 523 |
+
|
| 524 |
+
image_processor_class = None
|
| 525 |
+
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
|
| 526 |
+
if image_processor_type is not None:
|
| 527 |
+
# if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
|
| 528 |
+
if use_fast is None:
|
| 529 |
+
use_fast = image_processor_type.endswith("Fast")
|
| 530 |
+
if not use_fast and image_processor_type in FORCE_FAST_IMAGE_PROCESSOR and is_torchvision_available():
|
| 531 |
+
use_fast = True
|
| 532 |
+
logger.warning_once(
|
| 533 |
+
f"The image processor of type `{image_processor_type}` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. "
|
| 534 |
+
"This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. "
|
| 535 |
+
"Note that this behavior will be extended to all models in a future release."
|
| 536 |
+
)
|
| 537 |
+
if not use_fast:
|
| 538 |
+
logger.warning_once(
|
| 539 |
+
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
|
| 540 |
+
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
|
| 541 |
+
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
|
| 542 |
+
)
|
| 543 |
+
if use_fast and not image_processor_type.endswith("Fast"):
|
| 544 |
+
image_processor_type += "Fast"
|
| 545 |
+
if use_fast and not is_torchvision_available():
|
| 546 |
+
# check if there is a slow image processor class to fallback to
|
| 547 |
+
image_processor_class = get_image_processor_class_from_name(image_processor_type[:-4])
|
| 548 |
+
if image_processor_class is None:
|
| 549 |
+
raise ValueError(
|
| 550 |
+
f"`{image_processor_type}` requires `torchvision` to be installed. Please install `torchvision` and try again."
|
| 551 |
+
)
|
| 552 |
+
logger.warning_once(
|
| 553 |
+
"Using `use_fast=True` but `torchvision` is not available. Falling back to the slow image processor."
|
| 554 |
+
)
|
| 555 |
+
use_fast = False
|
| 556 |
+
if use_fast:
|
| 557 |
+
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
|
| 558 |
+
if image_processor_type in image_processors:
|
| 559 |
+
break
|
| 560 |
+
else:
|
| 561 |
+
image_processor_type = image_processor_type[:-4]
|
| 562 |
+
use_fast = False
|
| 563 |
+
logger.warning_once(
|
| 564 |
+
"`use_fast` is set to `True` but the image processor class does not have a fast version. "
|
| 565 |
+
" Falling back to the slow version."
|
| 566 |
+
)
|
| 567 |
+
image_processor_class = get_image_processor_class_from_name(image_processor_type)
|
| 568 |
+
else:
|
| 569 |
+
image_processor_type_slow = image_processor_type.removesuffix("Fast")
|
| 570 |
+
image_processor_class = get_image_processor_class_from_name(image_processor_type_slow)
|
| 571 |
+
if image_processor_class is None and image_processor_type.endswith("Fast"):
|
| 572 |
+
raise ValueError(
|
| 573 |
+
f"`{image_processor_type}` does not have a slow version. Please set `use_fast=True` when instantiating the processor."
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
has_remote_code = image_processor_auto_map is not None
|
| 577 |
+
has_local_code = image_processor_class is not None or type(config) in IMAGE_PROCESSOR_MAPPING
|
| 578 |
+
if has_remote_code:
|
| 579 |
+
if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple):
|
| 580 |
+
# In some configs, only the slow image processor class is stored
|
| 581 |
+
image_processor_auto_map = (image_processor_auto_map, None)
|
| 582 |
+
if use_fast and image_processor_auto_map[1] is not None:
|
| 583 |
+
class_ref = image_processor_auto_map[1]
|
| 584 |
+
else:
|
| 585 |
+
class_ref = image_processor_auto_map[0]
|
| 586 |
+
if "--" in class_ref:
|
| 587 |
+
upstream_repo = class_ref.split("--")[0]
|
| 588 |
+
else:
|
| 589 |
+
upstream_repo = None
|
| 590 |
+
trust_remote_code = resolve_trust_remote_code(
|
| 591 |
+
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
if has_remote_code and trust_remote_code:
|
| 595 |
+
if not use_fast and image_processor_auto_map[1] is not None:
|
| 596 |
+
_warning_fast_image_processor_available(image_processor_auto_map[1])
|
| 597 |
+
|
| 598 |
+
image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
|
| 599 |
+
_ = kwargs.pop("code_revision", None)
|
| 600 |
+
image_processor_class.register_for_auto_class()
|
| 601 |
+
return image_processor_class.from_dict(config_dict, **kwargs)
|
| 602 |
+
elif image_processor_class is not None:
|
| 603 |
+
return image_processor_class.from_dict(config_dict, **kwargs)
|
| 604 |
+
# Last try: we use the IMAGE_PROCESSOR_MAPPING.
|
| 605 |
+
elif type(config) in IMAGE_PROCESSOR_MAPPING:
|
| 606 |
+
image_processor_tuple = IMAGE_PROCESSOR_MAPPING[type(config)]
|
| 607 |
+
|
| 608 |
+
image_processor_class_py, image_processor_class_fast = image_processor_tuple
|
| 609 |
+
|
| 610 |
+
if not use_fast and image_processor_class_fast is not None:
|
| 611 |
+
_warning_fast_image_processor_available(image_processor_class_fast)
|
| 612 |
+
|
| 613 |
+
if image_processor_class_fast and (use_fast or image_processor_class_py is None):
|
| 614 |
+
return image_processor_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
| 615 |
+
else:
|
| 616 |
+
if image_processor_class_py is not None:
|
| 617 |
+
return image_processor_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
| 618 |
+
else:
|
| 619 |
+
raise ValueError(
|
| 620 |
+
"This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
|
| 621 |
+
)
|
| 622 |
+
raise ValueError(
|
| 623 |
+
f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
|
| 624 |
+
f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
|
| 625 |
+
f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in IMAGE_PROCESSOR_MAPPING_NAMES)}"
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
@staticmethod
|
| 629 |
+
def register(
|
| 630 |
+
config_class,
|
| 631 |
+
image_processor_class=None,
|
| 632 |
+
slow_image_processor_class=None,
|
| 633 |
+
fast_image_processor_class=None,
|
| 634 |
+
exist_ok=False,
|
| 635 |
+
):
|
| 636 |
+
"""
|
| 637 |
+
Register a new image processor for this class.
|
| 638 |
+
|
| 639 |
+
Args:
|
| 640 |
+
config_class ([`PretrainedConfig`]):
|
| 641 |
+
The configuration corresponding to the model to register.
|
| 642 |
+
image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
|
| 643 |
+
"""
|
| 644 |
+
if image_processor_class is not None:
|
| 645 |
+
if slow_image_processor_class is not None:
|
| 646 |
+
raise ValueError("Cannot specify both image_processor_class and slow_image_processor_class")
|
| 647 |
+
warnings.warn(
|
| 648 |
+
"The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead",
|
| 649 |
+
FutureWarning,
|
| 650 |
+
)
|
| 651 |
+
slow_image_processor_class = image_processor_class
|
| 652 |
+
|
| 653 |
+
if slow_image_processor_class is None and fast_image_processor_class is None:
|
| 654 |
+
raise ValueError("You need to specify either slow_image_processor_class or fast_image_processor_class")
|
| 655 |
+
if slow_image_processor_class is not None and issubclass(slow_image_processor_class, BaseImageProcessorFast):
|
| 656 |
+
raise ValueError("You passed a fast image processor in as the `slow_image_processor_class`.")
|
| 657 |
+
if fast_image_processor_class is not None and not issubclass(
|
| 658 |
+
fast_image_processor_class, BaseImageProcessorFast
|
| 659 |
+
):
|
| 660 |
+
raise ValueError("The `fast_image_processor_class` should inherit from `BaseImageProcessorFast`.")
|
| 661 |
+
|
| 662 |
+
if (
|
| 663 |
+
slow_image_processor_class is not None
|
| 664 |
+
and fast_image_processor_class is not None
|
| 665 |
+
and issubclass(fast_image_processor_class, BaseImageProcessorFast)
|
| 666 |
+
and fast_image_processor_class.slow_image_processor_class != slow_image_processor_class
|
| 667 |
+
):
|
| 668 |
+
raise ValueError(
|
| 669 |
+
"The fast processor class you are passing has a `slow_image_processor_class` attribute that is not "
|
| 670 |
+
"consistent with the slow processor class you passed (fast tokenizer has "
|
| 671 |
+
f"{fast_image_processor_class.slow_image_processor_class} and you passed {slow_image_processor_class}. Fix one of those "
|
| 672 |
+
"so they match!"
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
# Avoid resetting a set slow/fast image processor if we are passing just the other ones.
|
| 676 |
+
if config_class in IMAGE_PROCESSOR_MAPPING._extra_content:
|
| 677 |
+
existing_slow, existing_fast = IMAGE_PROCESSOR_MAPPING[config_class]
|
| 678 |
+
if slow_image_processor_class is None:
|
| 679 |
+
slow_image_processor_class = existing_slow
|
| 680 |
+
if fast_image_processor_class is None:
|
| 681 |
+
fast_image_processor_class = existing_fast
|
| 682 |
+
|
| 683 |
+
IMAGE_PROCESSOR_MAPPING.register(
|
| 684 |
+
config_class, (slow_image_processor_class, fast_image_processor_class), exist_ok=exist_ok
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
__all__ = ["IMAGE_PROCESSOR_MAPPING", "AutoImageProcessor"]
|
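The `register` hook above is the supported way to attach a custom slow/fast processor pair to a config class. A minimal sketch under the constraints the method enforces (`MyConfig`, `MyImageProcessor`, and `MyImageProcessorFast` are hypothetical names, not part of this diff):

```python
# Hypothetical sketch: registering a custom slow/fast image processor pair.
# All `My*` names are illustrative; only the register() contract comes from the code above.
from transformers import AutoImageProcessor, PretrainedConfig
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_utils_fast import BaseImageProcessorFast


class MyConfig(PretrainedConfig):
    model_type = "my-model"


class MyImageProcessor(BaseImageProcessor):
    pass


class MyImageProcessorFast(BaseImageProcessorFast):
    # Must point back at the slow class, or register() raises a ValueError.
    slow_image_processor_class = MyImageProcessor


AutoImageProcessor.register(
    MyConfig,
    slow_image_processor_class=MyImageProcessor,
    fast_image_processor_class=MyImageProcessorFast,
)
```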
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_auto.py
ADDED
The diff for this file is too large to render. See the raw diff.
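Although its diff is not rendered, `modeling_auto.py` defines the PyTorch `AutoModel*` classes on the same `_LazyAutoMapping` pattern as the Flax and TF files below. A minimal usage sketch, assuming only the well-known public API (the checkpoint name is illustrative):

```python
# Minimal sketch of the PyTorch Auto classes defined in modeling_auto.py
# (whose diff is not rendered above). The checkpoint name is illustrative.
from transformers import AutoModel, AutoModelForSequenceClassification

# `model_type` in the checkpoint config selects the concrete architecture.
model = AutoModel.from_pretrained("google-bert/bert-base-uncased")
classifier = AutoModelForSequenceClassification.from_pretrained(
    "google-bert/bert-base-uncased", num_labels=2
)
```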
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_flax_auto.py
ADDED
@@ -0,0 +1,413 @@
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Model class."""

from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "FlaxAlbertModel"),
        ("bart", "FlaxBartModel"),
        ("beit", "FlaxBeitModel"),
        ("bert", "FlaxBertModel"),
        ("big_bird", "FlaxBigBirdModel"),
        ("blenderbot", "FlaxBlenderbotModel"),
        ("blenderbot-small", "FlaxBlenderbotSmallModel"),
        ("bloom", "FlaxBloomModel"),
        ("clip", "FlaxCLIPModel"),
        ("dinov2", "FlaxDinov2Model"),
        ("distilbert", "FlaxDistilBertModel"),
        ("electra", "FlaxElectraModel"),
        ("gemma", "FlaxGemmaModel"),
        ("gpt-sw3", "FlaxGPT2Model"),
        ("gpt2", "FlaxGPT2Model"),
        ("gpt_neo", "FlaxGPTNeoModel"),
        ("gptj", "FlaxGPTJModel"),
        ("llama", "FlaxLlamaModel"),
        ("longt5", "FlaxLongT5Model"),
        ("marian", "FlaxMarianModel"),
        ("mbart", "FlaxMBartModel"),
        ("mistral", "FlaxMistralModel"),
        ("mt5", "FlaxMT5Model"),
        ("opt", "FlaxOPTModel"),
        ("pegasus", "FlaxPegasusModel"),
        ("regnet", "FlaxRegNetModel"),
        ("resnet", "FlaxResNetModel"),
        ("roberta", "FlaxRobertaModel"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
        ("roformer", "FlaxRoFormerModel"),
        ("t5", "FlaxT5Model"),
        ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
        ("vit", "FlaxViTModel"),
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("whisper", "FlaxWhisperModel"),
        ("xglm", "FlaxXGLMModel"),
        ("xlm-roberta", "FlaxXLMRobertaModel"),
    ]
)

FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "FlaxAlbertForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("electra", "FlaxElectraForPreTraining"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)

FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "FlaxAlbertForMaskedLM"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForMaskedLM"),
        ("big_bird", "FlaxBigBirdForMaskedLM"),
        ("distilbert", "FlaxDistilBertForMaskedLM"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)

FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "FlaxBartForConditionalGeneration"),
        ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "FlaxEncoderDecoderModel"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("marian", "FlaxMarianMTModel"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("pegasus", "FlaxPegasusForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
    ]
)

FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image-classification
        ("beit", "FlaxBeitForImageClassification"),
        ("dinov2", "FlaxDinov2ForImageClassification"),
        ("regnet", "FlaxRegNetForImageClassification"),
        ("resnet", "FlaxResNetForImageClassification"),
        ("vit", "FlaxViTForImageClassification"),
    ]
)

FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
    ]
)

FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "FlaxBartForCausalLM"),
        ("bert", "FlaxBertForCausalLM"),
        ("big_bird", "FlaxBigBirdForCausalLM"),
        ("bloom", "FlaxBloomForCausalLM"),
        ("electra", "FlaxElectraForCausalLM"),
        ("gemma", "FlaxGemmaForCausalLM"),
        ("gpt-sw3", "FlaxGPT2LMHeadModel"),
        ("gpt2", "FlaxGPT2LMHeadModel"),
        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
        ("gptj", "FlaxGPTJForCausalLM"),
        ("llama", "FlaxLlamaForCausalLM"),
        ("mistral", "FlaxMistralForCausalLM"),
        ("opt", "FlaxOPTForCausalLM"),
        ("roberta", "FlaxRobertaForCausalLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
        ("xglm", "FlaxXGLMForCausalLM"),
        ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
    ]
)

FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "FlaxAlbertForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("distilbert", "FlaxDistilBertForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
    ]
)

FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "FlaxAlbertForQuestionAnswering"),
        ("bart", "FlaxBartForQuestionAnswering"),
        ("bert", "FlaxBertForQuestionAnswering"),
        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
        ("distilbert", "FlaxDistilBertForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
        ("roberta", "FlaxRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "FlaxRoFormerForQuestionAnswering"),
        ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
    ]
)

FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "FlaxAlbertForTokenClassification"),
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("distilbert", "FlaxDistilBertForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
        ("roberta", "FlaxRobertaForTokenClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
        ("roformer", "FlaxRoFormerForTokenClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
    ]
)

FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "FlaxAlbertForMultipleChoice"),
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("distilbert", "FlaxDistilBertForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
        ("roberta", "FlaxRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "FlaxRoFormerForMultipleChoice"),
        ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
    ]
)

FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)

FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
    ]
)

FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("whisper", "FlaxWhisperForAudioClassification"),
    ]
)

FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)


class FlaxAutoModel(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_MAPPING


FlaxAutoModel = auto_class_update(FlaxAutoModel)


class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING


FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")


class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING


FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")


class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING


FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")


class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


FlaxAutoModelForSeq2SeqLM = auto_class_update(
    FlaxAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)


class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


FlaxAutoModelForSequenceClassification = auto_class_update(
    FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
)


class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING


FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")


class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


FlaxAutoModelForTokenClassification = auto_class_update(
    FlaxAutoModelForTokenClassification, head_doc="token classification"
)


class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")


class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


FlaxAutoModelForNextSentencePrediction = auto_class_update(
    FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification, head_doc="image classification"
)


class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING


FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)

__all__ = [
    "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
    "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
    "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
    "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
    "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
    "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
    "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
    "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
    "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
    "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
    "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
    "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
    "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
    "FLAX_MODEL_MAPPING",
    "FlaxAutoModel",
    "FlaxAutoModelForCausalLM",
    "FlaxAutoModelForImageClassification",
    "FlaxAutoModelForMaskedLM",
    "FlaxAutoModelForMultipleChoice",
    "FlaxAutoModelForNextSentencePrediction",
    "FlaxAutoModelForPreTraining",
    "FlaxAutoModelForQuestionAnswering",
    "FlaxAutoModelForSeq2SeqLM",
    "FlaxAutoModelForSequenceClassification",
    "FlaxAutoModelForSpeechSeq2Seq",
    "FlaxAutoModelForTokenClassification",
    "FlaxAutoModelForVision2Seq",
]
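Each Flax auto class above binds a task-specific `_model_mapping`, and `auto_class_update` fills in its documentation. A minimal usage sketch, assuming the `flax` extra is installed (the checkpoint name is illustrative):

```python
# Minimal sketch: resolving a concrete Flax architecture through FlaxAutoModel.
# Requires the flax/jax extras; the checkpoint name is illustrative.
from transformers import FlaxAutoModel

# `model_type: "bert"` in the checkpoint config routes through
# FLAX_MODEL_MAPPING to FlaxBertModel.
model = FlaxAutoModel.from_pretrained("google-bert/bert-base-uncased")
print(type(model).__name__)  # FlaxBertModel
```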
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/modeling_tf_auto.py
ADDED
@@ -0,0 +1,776 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Model class."""

import warnings
from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("convnextv2", "TFConvNextV2Model"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        ("dpr", "TFDPRQuestionEncoder"),
        ("efficientformer", "TFEfficientFormerModel"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("idefics", "TFIdeficsModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mistral", "TFMistralModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("sam", "TFSamModel"),
        ("sam_vision_model", "TFSamVisionModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swiftformer", "TFSwiftFormerModel"),
        ("swin", "TFSwinModel"),
        ("t5", "TFT5Model"),
        ("tapas", "TFTapasModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
        ("vit", "TFViTModel"),
        ("vit_mae", "TFViTMAEModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("whisper", "TFWhisperModel"),
        ("xglm", "TFXGLMModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("xlnet", "TFXLNetModel"),
    ]
)

TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("idefics", "TFIdeficsForVisionText2Text"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("mistral", "TFMistralForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
|
| 199 |
+
[
|
| 200 |
+
("deit", "TFDeiTForMaskedImageModeling"),
|
| 201 |
+
("swin", "TFSwinForMaskedImageModeling"),
|
| 202 |
+
]
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
| 206 |
+
[
|
| 207 |
+
        # Model for Image classification mapping
        ("convnext", "TFConvNextForImageClassification"),
        ("convnextv2", "TFConvNextV2ForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        (
            "efficientformer",
            ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
        ),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swiftformer", "TFSwiftFormerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)


TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)


TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)

TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "TFBlipForConditionalGeneration"),
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    ]
)

TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)

TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)

TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    ]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mistral", "TFMistralForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)

TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])

TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
    ]
)


TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TFTapasForQuestionAnswering"),
    ]
)

TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)

TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)

TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    ]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        ("sam", "TFSamModel"),
    ]
)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "TFAlbertModel"),
        ("bert", "TFBertModel"),
        ("convbert", "TFConvBertModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("distilbert", "TFDistilBertModel"),
        ("electra", "TFElectraModel"),
        ("flaubert", "TFFlaubertModel"),
        ("longformer", "TFLongformerModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mt5", "TFMT5EncoderModel"),
        ("rembert", "TFRemBertModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("t5", "TFT5EncoderModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
    ]
)

TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)

TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)

TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)


class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING


class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING


class TFAutoModel(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_MAPPING


TFAutoModel = auto_class_update(TFAutoModel)


class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification, head_doc="audio classification"
)


class TFAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")


class TFAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")


class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
)


class TFAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification, head_doc="image classification"
)


class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)


class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")


class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)


class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification, head_doc="sequence classification"
)


class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")


class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification, head_doc="token classification"
)


class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")


class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


__all__ = [
    "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_CAUSAL_LM_MAPPING",
    "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_MASK_GENERATION_MAPPING",
    "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
    "TF_MODEL_FOR_MASKED_LM_MAPPING",
    "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
    "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
    "TF_MODEL_FOR_PRETRAINING_MAPPING",
    "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
    "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
    "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
    "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
    "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
    "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
    "TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
    "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
    "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
    "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
    "TF_MODEL_MAPPING",
    "TF_MODEL_WITH_LM_HEAD_MAPPING",
    "TFAutoModel",
    "TFAutoModelForAudioClassification",
    "TFAutoModelForCausalLM",
    "TFAutoModelForImageClassification",
    "TFAutoModelForMaskedImageModeling",
    "TFAutoModelForMaskedLM",
    "TFAutoModelForMaskGeneration",
    "TFAutoModelForMultipleChoice",
    "TFAutoModelForNextSentencePrediction",
    "TFAutoModelForPreTraining",
    "TFAutoModelForDocumentQuestionAnswering",
    "TFAutoModelForQuestionAnswering",
    "TFAutoModelForSemanticSegmentation",
    "TFAutoModelForSeq2SeqLM",
    "TFAutoModelForSequenceClassification",
    "TFAutoModelForSpeechSeq2Seq",
    "TFAutoModelForTableQuestionAnswering",
    "TFAutoModelForTextEncoding",
    "TFAutoModelForTokenClassification",
    "TFAutoModelForVision2Seq",
    "TFAutoModelForZeroShotImageClassification",
    "TFAutoModelWithLMHead",
]
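
(The mappings above resolve lazily: `_LazyAutoMapping` only imports a model class once its config's `model_type` is actually looked up, so importing this module stays cheap. A minimal usage sketch, assuming a transformers build with TensorFlow support installed; the checkpoint name is only an illustrative example:)

from transformers import TFAutoModel, TFAutoModelForSequenceClassification

# The checkpoint's config.model_type (e.g. "bert") selects the class through
# TF_MODEL_MAPPING, here TFBertModel.
model = TFAutoModel.from_pretrained("bert-base-uncased")

# Task-specific heads go through their own mapping, e.g.
# TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING -> TFBertForSequenceClassification.
classifier = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
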
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/processing_auto.py
ADDED
@@ -0,0 +1,443 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoProcessor class."""

import importlib
import inspect
import json
import warnings
from collections import OrderedDict

# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)
from .feature_extraction_auto import AutoFeatureExtractor
from .image_processing_auto import AutoImageProcessor
from .tokenization_auto import AutoTokenizer


logger = logging.get_logger(__name__)

PROCESSOR_MAPPING_NAMES = OrderedDict(
    [
        ("aimv2", "CLIPProcessor"),
        ("align", "AlignProcessor"),
        ("altclip", "AltCLIPProcessor"),
        ("aria", "AriaProcessor"),
        ("aya_vision", "AyaVisionProcessor"),
        ("bark", "BarkProcessor"),
        ("blip", "BlipProcessor"),
        ("blip-2", "Blip2Processor"),
        ("bridgetower", "BridgeTowerProcessor"),
        ("chameleon", "ChameleonProcessor"),
        ("chinese_clip", "ChineseCLIPProcessor"),
        ("clap", "ClapProcessor"),
        ("clip", "CLIPProcessor"),
        ("clipseg", "CLIPSegProcessor"),
        ("clvp", "ClvpProcessor"),
        ("cohere2_vision", "Cohere2VisionProcessor"),
        ("colpali", "ColPaliProcessor"),
        ("colqwen2", "ColQwen2Processor"),
        ("deepseek_vl", "DeepseekVLProcessor"),
        ("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"),
        ("dia", "DiaProcessor"),
        ("edgetam", "Sam2Processor"),
        ("emu3", "Emu3Processor"),
        ("evolla", "EvollaProcessor"),
        ("flava", "FlavaProcessor"),
        ("florence2", "Florence2Processor"),
        ("fuyu", "FuyuProcessor"),
        ("gemma3", "Gemma3Processor"),
        ("gemma3n", "Gemma3nProcessor"),
        ("git", "GitProcessor"),
        ("glm4v", "Glm4vProcessor"),
        ("glm4v_moe", "Glm4vProcessor"),
        ("got_ocr2", "GotOcr2Processor"),
        ("granite_speech", "GraniteSpeechProcessor"),
        ("grounding-dino", "GroundingDinoProcessor"),
        ("groupvit", "CLIPProcessor"),
        ("hubert", "Wav2Vec2Processor"),
        ("idefics", "IdeficsProcessor"),
        ("idefics2", "Idefics2Processor"),
        ("idefics3", "Idefics3Processor"),
        ("instructblip", "InstructBlipProcessor"),
        ("instructblipvideo", "InstructBlipVideoProcessor"),
        ("internvl", "InternVLProcessor"),
        ("janus", "JanusProcessor"),
        ("kosmos-2", "Kosmos2Processor"),
        ("kosmos-2.5", "Kosmos2_5Processor"),
        ("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
        ("layoutlmv2", "LayoutLMv2Processor"),
        ("layoutlmv3", "LayoutLMv3Processor"),
        ("lfm2_vl", "Lfm2VlProcessor"),
        ("llama4", "Llama4Processor"),
        ("llava", "LlavaProcessor"),
        ("llava_next", "LlavaNextProcessor"),
        ("llava_next_video", "LlavaNextVideoProcessor"),
        ("llava_onevision", "LlavaOnevisionProcessor"),
        ("markuplm", "MarkupLMProcessor"),
        ("mctct", "MCTCTProcessor"),
        ("metaclip_2", "CLIPProcessor"),
        ("mgp-str", "MgpstrProcessor"),
        ("mistral3", "PixtralProcessor"),
        ("mllama", "MllamaProcessor"),
        ("mm-grounding-dino", "GroundingDinoProcessor"),
        ("moonshine", "Wav2Vec2Processor"),
        ("oneformer", "OneFormerProcessor"),
        ("ovis2", "Ovis2Processor"),
        ("owlv2", "Owlv2Processor"),
        ("owlvit", "OwlViTProcessor"),
        ("paligemma", "PaliGemmaProcessor"),
        ("perception_lm", "PerceptionLMProcessor"),
        ("phi4_multimodal", "Phi4MultimodalProcessor"),
        ("pix2struct", "Pix2StructProcessor"),
        ("pixtral", "PixtralProcessor"),
        ("pop2piano", "Pop2PianoProcessor"),
        ("qwen2_5_omni", "Qwen2_5OmniProcessor"),
        ("qwen2_5_vl", "Qwen2_5_VLProcessor"),
        ("qwen2_audio", "Qwen2AudioProcessor"),
        ("qwen2_vl", "Qwen2VLProcessor"),
        ("qwen3_omni_moe", "Qwen3OmniMoeProcessor"),
        ("qwen3_vl", "Qwen3VLProcessor"),
        ("qwen3_vl_moe", "Qwen3VLProcessor"),
        ("sam", "SamProcessor"),
        ("sam2", "Sam2Processor"),
        ("sam_hq", "SamHQProcessor"),
        ("seamless_m4t", "SeamlessM4TProcessor"),
        ("sew", "Wav2Vec2Processor"),
        ("sew-d", "Wav2Vec2Processor"),
        ("shieldgemma2", "ShieldGemma2Processor"),
        ("siglip", "SiglipProcessor"),
        ("siglip2", "Siglip2Processor"),
        ("smolvlm", "SmolVLMProcessor"),
        ("speech_to_text", "Speech2TextProcessor"),
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("speecht5", "SpeechT5Processor"),
        ("trocr", "TrOCRProcessor"),
        ("tvlt", "TvltProcessor"),
        ("tvp", "TvpProcessor"),
        ("udop", "UdopProcessor"),
        ("unispeech", "Wav2Vec2Processor"),
        ("unispeech-sat", "Wav2Vec2Processor"),
        ("video_llava", "VideoLlavaProcessor"),
        ("vilt", "ViltProcessor"),
        ("vipllava", "LlavaProcessor"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
        ("voxtral", "VoxtralProcessor"),
        ("wav2vec2", "Wav2Vec2Processor"),
        ("wav2vec2-bert", "Wav2Vec2Processor"),
        ("wav2vec2-conformer", "Wav2Vec2Processor"),
        ("wavlm", "Wav2Vec2Processor"),
        ("whisper", "WhisperProcessor"),
        ("xclip", "XCLIPProcessor"),
    ]
)

PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)


def processor_class_from_name(class_name: str):
    for module_name, processors in PROCESSOR_MAPPING_NAMES.items():
        if class_name in processors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for processor in PROCESSOR_MAPPING._extra_content.values():
        if getattr(processor, "__name__", None) == class_name:
            return processor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


class AutoProcessor:
    r"""
    This is a generic processor class that will be instantiated as one of the processor classes of the library when
    created with the [`AutoProcessor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoProcessor is designed to be instantiated "
            "using the `AutoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(PROCESSOR_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate one of the processor classes of the library from a pretrained model vocabulary.

        The processor class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible):

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
                  e.g., `./my_model_directory/`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force to (re-)download the feature extractor files and override the cached versions
                if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `hf auth login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final feature extractor object. If `True`, then this
                function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
                consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
                `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (`dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are feature extractor attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
                controlled by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `token=True` is required when you want to use a private model.

        </Tip>

        Examples:

        ```python
        >>> from transformers import AutoProcessor

        >>> # Download processor from huggingface.co and cache.
        >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

        >>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
        >>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        processor_class = None
        processor_auto_map = None

        # First, let's see if we have a processor or preprocessor config.
        # Filter the kwargs for `cached_file`.
        cached_file_kwargs = {key: kwargs[key] for key in inspect.signature(cached_file).parameters if key in kwargs}
        # We don't want to raise
        cached_file_kwargs.update(
            {
                "_raise_exceptions_for_gated_repo": False,
                "_raise_exceptions_for_missing_entries": False,
                "_raise_exceptions_for_connection_errors": False,
            }
        )

        # Let's start by checking whether the processor class is saved in a processor config
        processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
        if processor_config_file is not None:
            config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
            processor_class = config_dict.get("processor_class", None)
            if "AutoProcessor" in config_dict.get("auto_map", {}):
                processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

        if processor_class is None:
            # If not found, let's check whether the processor class is saved in an image processor config
            preprocessor_config_file = cached_file(
                pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
            )
            if preprocessor_config_file is not None:
                config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
                processor_class = config_dict.get("processor_class", None)
                if "AutoProcessor" in config_dict.get("auto_map", {}):
                    processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

            # Saved as video processor
            if preprocessor_config_file is None:
                preprocessor_config_file = cached_file(
                    pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
                )
                if preprocessor_config_file is not None:
                    config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
                        pretrained_model_name_or_path, **kwargs
                    )
                    processor_class = config_dict.get("processor_class", None)
                    if "AutoProcessor" in config_dict.get("auto_map", {}):
                        processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

            # Saved as feature extractor
            if preprocessor_config_file is None:
                preprocessor_config_file = cached_file(
                    pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
                )
                if preprocessor_config_file is not None and processor_class is None:
                    config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
                        pretrained_model_name_or_path, **kwargs
                    )
                    processor_class = config_dict.get("processor_class", None)
                    if "AutoProcessor" in config_dict.get("auto_map", {}):
                        processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

        if processor_class is None:
            # Next, let's check whether the processor class is saved in a tokenizer
            tokenizer_config_file = cached_file(
                pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
            )
            if tokenizer_config_file is not None:
                with open(tokenizer_config_file, encoding="utf-8") as reader:
                    config_dict = json.load(reader)

                processor_class = config_dict.get("processor_class", None)
                if "AutoProcessor" in config_dict.get("auto_map", {}):
                    processor_auto_map = config_dict["auto_map"]["AutoProcessor"]

        if processor_class is None:
            # Otherwise, load config, if it can be loaded.
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )

            # And check if the config contains the processor class.
            processor_class = getattr(config, "processor_class", None)
            if hasattr(config, "auto_map") and "AutoProcessor" in config.auto_map:
                processor_auto_map = config.auto_map["AutoProcessor"]

        if processor_class is not None:
            processor_class = processor_class_from_name(processor_class)

        has_remote_code = processor_auto_map is not None
        has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING
        if has_remote_code:
            if "--" in processor_auto_map:
                upstream_repo = processor_auto_map.split("--")[0]
            else:
                upstream_repo = None
            trust_remote_code = resolve_trust_remote_code(
                trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
            )

        if has_remote_code and trust_remote_code:
            processor_class = get_class_from_dynamic_module(
                processor_auto_map, pretrained_model_name_or_path, **kwargs
            )
            _ = kwargs.pop("code_revision", None)
            processor_class.register_for_auto_class()
            return processor_class.from_pretrained(
                pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
            )
        elif processor_class is not None:
            return processor_class.from_pretrained(
                pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
            )
        # Last try: we use the PROCESSOR_MAPPING.
        elif type(config) in PROCESSOR_MAPPING:
            return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)

        # At this stage, there doesn't seem to be a `Processor` class available for this model, so let's try a
        # tokenizer.
        try:
            return AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
            )
        except Exception:
            try:
                return AutoImageProcessor.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            except Exception:
                pass

            try:
                return AutoFeatureExtractor.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            except Exception:
                pass

        raise ValueError(
            f"Unrecognized processing class in {pretrained_model_name_or_path}. Can't instantiate a processor, a "
            "tokenizer, an image processor or a feature extractor for this model. Make sure the repository contains "
            "the files of at least one of those processing classes."
        )

    @staticmethod
    def register(config_class, processor_class, exist_ok=False):
        """
        Register a new processor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            processor_class ([`ProcessorMixin`]): The processor to register.
        """
        PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)


__all__ = ["PROCESSOR_MAPPING", "AutoProcessor"]
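
(As a rough sketch of the resolution order implemented above -- processor config, then image/video/feature-extractor configs, then tokenizer config, then the model config, with tokenizer, image-processor, and feature-extractor fallbacks at the end -- assuming network access to huggingface.co; the checkpoint names are illustrative only, and the register call uses hypothetical class names:)

from transformers import AutoProcessor

# A checkpoint whose model_type is registered in PROCESSOR_MAPPING_NAMES
# (e.g. "whisper" -> WhisperProcessor) returns that processor class.
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")

# A text-only checkpoint with no processor falls through to the
# AutoTokenizer.from_pretrained fallback and returns a tokenizer instead.
tokenizer = AutoProcessor.from_pretrained("bert-base-uncased")

# Custom pairs can be wired into the same mapping (hypothetical classes):
# AutoProcessor.register(MyConfig, MyProcessor)
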
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/tokenization_auto.py
ADDED
|
@@ -0,0 +1,1235 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Tokenizer class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import Any, Optional, Union

from transformers.utils.import_utils import is_mistral_common_available

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...utils import (
    cached_file,
    extract_commit_hash,
    is_g2p_en_available,
    is_sentencepiece_available,
    is_tokenizers_available,
    logging,
)
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    config_class_to_model_type,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


if is_tokenizers_available():
    from ...tokenization_utils_fast import PreTrainedTokenizerFast
else:
    PreTrainedTokenizerFast = None


logger = logging.get_logger(__name__)

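# Editor's note (not part of the original module): the guarded import above is
# the standard optional-dependency pattern used throughout this file. A minimal
# standalone version of the same idea, for illustration only:
#
#     try:
#         import tokenizers  # optional Rust backend
#     except ImportError:
#         tokenizers = None
#
#     def require_fast_backend():
#         if tokenizers is None:
#             raise ImportError("Install `tokenizers` to use fast tokenizers.")
#
# transformers centralizes the check in is_tokenizers_available() and binds
# PreTrainedTokenizerFast to None so the name always exists either way.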
# Explicit rather than inferred generics significantly improve completion-suggestion performance for language servers.
TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
    [
        ("aimv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("albert", ("AlbertTokenizer" if is_sentencepiece_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None)),
        ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("bart", ("BartTokenizer", "BartTokenizerFast")),
        ("barthez", ("BarthezTokenizer" if is_sentencepiece_available() else None, "BarthezTokenizerFast" if is_tokenizers_available() else None)),
        ("bartpho", ("BartphoTokenizer", None)),
        ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
        ("bert-japanese", ("BertJapaneseTokenizer", None)),
        ("bertweet", ("BertweetTokenizer", None)),
        ("big_bird", ("BigBirdTokenizer" if is_sentencepiece_available() else None, "BigBirdTokenizerFast" if is_tokenizers_available() else None)),
        ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("biogpt", ("BioGptTokenizer", None)),
        ("bitnet", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
        ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
        ("blip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("blip-2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("bloom", (None, "BloomTokenizerFast" if is_tokenizers_available() else None)),
        ("blt", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("bridgetower", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("bros", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("byt5", ("ByT5Tokenizer", None)),
        ("camembert", ("CamembertTokenizer" if is_sentencepiece_available() else None, "CamembertTokenizerFast" if is_tokenizers_available() else None)),
        ("canine", ("CanineTokenizer", None)),
        ("chameleon", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("clap", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("clip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("clipseg", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("clvp", ("ClvpTokenizer", None)),
        ("code_llama", ("CodeLlamaTokenizer" if is_sentencepiece_available() else None, "CodeLlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
        ("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
        ("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
        ("cpm", ("CpmTokenizer" if is_sentencepiece_available() else None, "CpmTokenizerFast" if is_tokenizers_available() else None)),
        ("cpmant", ("CpmAntTokenizer", None)),
        ("csm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("ctrl", ("CTRLTokenizer", None)),
        ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
        ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
        ("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, "DebertaV2TokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_v2", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_v3", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_vl", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("deepseek_vl_hybrid", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("dia", ("DiaTokenizer", None)),
        ("diffllama", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
        ("dpr", ("DPRQuestionEncoderTokenizer", "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None)),
        ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
        ("emu3", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("ernie", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("ernie4_5", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
        ("esm", ("EsmTokenizer", None)),
        ("exaone4", ("GPT2Tokenizer" if is_tokenizers_available() else None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("fastspeech2_conformer", ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None)),
        ("flaubert", ("FlaubertTokenizer", None)),
        ("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
        ("fsmt", ("FSMTTokenizer", None)),
        ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma2", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3_text", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3n", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("gemma3n_text", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
        ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
        ("gpt_oss", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
        ("granite", ("GPT2Tokenizer", None)),
        ("granitemoe", ("GPT2Tokenizer", None)),
        ("granitemoehybrid", ("GPT2Tokenizer", None)),
        ("granitemoeshared", ("GPT2Tokenizer", None)),
        ("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
        ("hubert", ("Wav2Vec2CTCTokenizer", None)),
        ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("internvl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("jamba", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("jetmoe", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("jukebox", ("JukeboxTokenizer", None)),
        ("kosmos-2", ("XLMRobertaTokenizer" if is_sentencepiece_available() else None, "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
        ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
        ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
        ("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llama4", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llama4_text", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("llava_onevision", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
        ("longt5", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("luke", ("LukeTokenizer", None)),
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
        ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
        ("mbart", ("MBartTokenizer" if is_sentencepiece_available() else None, "MBartTokenizerFast" if is_tokenizers_available() else None)),
        ("mbart50", ("MBart50Tokenizer" if is_sentencepiece_available() else None, "MBart50TokenizerFast" if is_tokenizers_available() else None)),
        ("mega", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("metaclip_2", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("mgp-str", ("MgpstrTokenizer", None)),
        ("minimax", ("GPT2Tokenizer" if is_sentencepiece_available() else None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        (
            "ministral",
            (
                "MistralCommonTokenizer"
                if is_mistral_common_available()
                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
            ),
        ),
        (
            "mistral",
            (
                "MistralCommonTokenizer"
                if is_mistral_common_available()
                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
            ),
        ),
        (
            "mistral3",
            (
                "MistralCommonTokenizer"
                if is_mistral_common_available()
                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
            ),
        ),
        (
            "mixtral",
            (
                "MistralCommonTokenizer"
                if is_mistral_common_available()
                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
            ),
        ),
        ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
        ("mm-grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
        ("modernbert", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("moonshine", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("moshi", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
        ("mpt", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("mra", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("mt5", ("MT5Tokenizer" if is_sentencepiece_available() else None, "MT5TokenizerFast" if is_tokenizers_available() else None)),
        ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("musicgen_melody", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
        ("myt5", ("MyT5Tokenizer", None)),
        ("nemotron", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("nllb", ("NllbTokenizer" if is_sentencepiece_available() else None, "NllbTokenizerFast" if is_tokenizers_available() else None)),
        ("nllb-moe", ("NllbTokenizer" if is_sentencepiece_available() else None, "NllbTokenizerFast" if is_tokenizers_available() else None)),
        ("nystromformer", ("AlbertTokenizer" if is_sentencepiece_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo3", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("omdet-turbo", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
        ("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("parakeet", ("ParakeetCTCTokenizer", None)),
        ("pegasus", ("PegasusTokenizer" if is_sentencepiece_available() else None, "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("pegasus_x", ("PegasusTokenizer" if is_sentencepiece_available() else None, "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("perceiver", ("PerceiverTokenizer", None)),
        ("persimmon", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
        ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("phobert", ("PhobertTokenizer", None)),
        ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
        (
            "pixtral",
            (
                None,
                "MistralCommonTokenizer"
                if is_mistral_common_available()
                else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
            ),
        ),
        ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
        ("prophetnet", ("ProphetNetTokenizer", None)),
        ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_5_omni", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_5_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_audio", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen2_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_next", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_omni_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_vl", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("qwen3_vl_moe", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
        ("rag", ("RagTokenizer", None)),
        ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
        ("recurrent_gemma", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("reformer", ("ReformerTokenizer" if is_sentencepiece_available() else None, "ReformerTokenizerFast" if is_tokenizers_available() else None)),
        ("rembert", ("RemBertTokenizer" if is_sentencepiece_available() else None, "RemBertTokenizerFast" if is_tokenizers_available() else None)),
        ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta-prelayernorm", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("roc_bert", ("RoCBertTokenizer", None)),
        ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
        ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("seamless_m4t", ("SeamlessM4TTokenizer" if is_sentencepiece_available() else None, "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None)),
        ("seamless_m4t_v2", ("SeamlessM4TTokenizer" if is_sentencepiece_available() else None, "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None)),
        ("shieldgemma2", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
        ("siglip2", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
        ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
        ("squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None)),
        ("stablelm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("starcoder2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("switch_transformers", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("t5", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("t5gemma", ("GemmaTokenizer" if is_sentencepiece_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None)),
        ("tapas", ("TapasTokenizer", None)),
        ("tapex", ("TapexTokenizer", None)),
        ("transfo-xl", ("TransfoXLTokenizer", None)),
        ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("udop", ("UdopTokenizer" if is_sentencepiece_available() else None, "UdopTokenizerFast" if is_tokenizers_available() else None)),
        ("umt5", ("T5Tokenizer" if is_sentencepiece_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None)),
        ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("vits", ("VitsTokenizer", None)),
        (
            "voxtral",
            (
                "MistralCommonTokenizer" if is_mistral_common_available() else None,
                "PreTrainedTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
            ),
        ),
        ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
        ("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
        ("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
        ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
        ("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
        ("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        ("xglm", ("XGLMTokenizer" if is_sentencepiece_available() else None, "XGLMTokenizerFast" if is_tokenizers_available() else None)),
        ("xlm", ("XLMTokenizer", None)),
        ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
        ("xlm-roberta", ("XLMRobertaTokenizer" if is_sentencepiece_available() else None, "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("xlm-roberta-xl", ("XLMRobertaTokenizer" if is_sentencepiece_available() else None, "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("xlnet", ("XLNetTokenizer" if is_sentencepiece_available() else None, "XLNetTokenizerFast" if is_tokenizers_available() else None)),
        ("xlstm", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("xmod", ("XLMRobertaTokenizer" if is_sentencepiece_available() else None, "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("yoso", ("AlbertTokenizer" if is_sentencepiece_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None)),
        ("zamba", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
        ("zamba2", ("LlamaTokenizer" if is_sentencepiece_available() else None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
    ]
)

TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)

CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}

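# ---------------------------------------------------------------------------
# Editor's note (not part of the original module): a minimal sketch of how the
# (slow, fast) tuples above are consumed. Each TOKENIZER_MAPPING_NAMES value is
# a pair of class *names*; either half is None when its optional dependency
# (sentencepiece for most slow tokenizers, tokenizers for fast ones) is missing.
# For example:
#
#     slow_name, fast_name = TOKENIZER_MAPPING_NAMES["bert"]
#     # -> ("BertTokenizer", "BertTokenizerFast") when `tokenizers` is installed,
#     # -> ("BertTokenizer", None) otherwise.
#
# TOKENIZER_MAPPING then lazily maps *config classes* (e.g. BertConfig) to the
# same pairs, importing the relevant tokenizer module only on first access.
# ---------------------------------------------------------------------------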

def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
    if class_name == "PreTrainedTokenizerFast":
        return PreTrainedTokenizerFast

    for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
        if class_name in tokenizers:
            module_name = model_type_to_module_name(module_name)
            if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonTokenizer":
                module = importlib.import_module(".tokenization_mistral_common", "transformers")
            else:
                module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for tokenizers in TOKENIZER_MAPPING._extra_content.values():
        for tokenizer in tokenizers:
            if getattr(tokenizer, "__name__", None) == class_name:
                return tokenizer

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None

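# Editor's sketch (not part of the original module): resolving a class name with
# the helper above.
#
#     from transformers.models.auto.tokenization_auto import tokenizer_class_from_name
#     cls = tokenizer_class_from_name("BertTokenizerFast")
#     # `cls` is transformers.BertTokenizerFast when `tokenizers` is installed;
#     # per the fallback above, a missing dependency instead yields the dummy
#     # object from the main init, which raises a helpful error when used.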

def get_tokenizer_config(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]],
    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    **kwargs,
) -> dict[str, Any]:
    """
    Loads the tokenizer configuration from a pretrained model tokenizer configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force (re-)downloading the configuration files and overriding the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5
            of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the tokenizer config is located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `dict`: The configuration of the tokenizer.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    tokenizer_config = get_tokenizer_config("google-bert/bert-base-uncased")
    # This model does not have a tokenizer config so the result will be an empty dict.
    tokenizer_config = get_tokenizer_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained tokenizer locally and you can reload its config
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
    tokenizer.save_pretrained("tokenizer-test")
    tokenizer_config = get_tokenizer_config("tokenizer-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    commit_hash = kwargs.get("_commit_hash")
    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        TOKENIZER_CONFIG_FILE,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
        _commit_hash=commit_hash,
    )
    if resolved_config_file is None:
        logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
        return {}
    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

    with open(resolved_config_file, encoding="utf-8") as reader:
        result = json.load(reader)
    result["_commit_hash"] = commit_hash
    return result

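# Editor's note (not part of the original module): AutoTokenizer.from_pretrained
# below resolves the tokenizer class roughly in this order, based on the code
# visible in this file:
#   1. an explicit `tokenizer_type` argument, looked up in TOKENIZER_MAPPING_NAMES;
#   2. the `tokenizer_class` / `auto_map` entries of the repo's tokenizer_config.json
#      (loaded via get_tokenizer_config above);
#   3. the model config's `tokenizer_class`, falling back to TOKENIZER_MAPPING
#      keyed by the config class.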

class AutoTokenizer:
    r"""
    This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when
    created with the [`AutoTokenizer.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.

        The tokenizer class to instantiate is selected based on the `model_type` property of the config object
        (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                  using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
                - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a
                  single vocabulary file (like Bert or XLNet), e.g.: `./my_model_directory/vocab.txt`. (Not
                  applicable to all derived classes)
            inputs (additional positional arguments, *optional*):
                Will be passed along to the Tokenizer `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                The configuration object used to determine the tokenizer class to instantiate.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in
                v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we
                use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can
                be any identifier allowed by git.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co
                (e.g. for facebook/rag-token-base), specify it here.
            use_fast (`bool`, *optional*, defaults to `True`):
                Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) if it is supported
                for a given model. If a fast tokenizer is not available for a given model, a normal Python-based
                tokenizer is returned instead.
            tokenizer_type (`str`, *optional*):
                Tokenizer type to be loaded.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This
                option should only be set to `True` for repositories you trust and in which you have read the
                code, as it will execute code present on the Hub on your local machine.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the Tokenizer `__init__()` method. Can be used to set special tokens like
                `bos_token`, `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
                `additional_special_tokens`. See parameters in the `__init__()` for more details.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer

        >>> # Download vocabulary from huggingface.co and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

        >>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
        >>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

        >>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
        >>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")

        >>> # Download vocabulary from huggingface.co and define model-specific arguments
        >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        kwargs["_from_auto"] = True

        use_fast = kwargs.pop("use_fast", True)
        tokenizer_type = kwargs.pop("tokenizer_type", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        gguf_file = kwargs.get("gguf_file")

        # First, let's see whether the tokenizer_type is passed so that we can leverage it
        if tokenizer_type is not None:
            tokenizer_class = None
            tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

            if tokenizer_class_tuple is None:
                raise ValueError(
                    f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
                    f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES)}."
                )

            tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

            if use_fast:
                if tokenizer_fast_class_name is not None:
                    tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)
                else:
                    logger.warning(
                        "`use_fast` is set to `True` but the tokenizer class does not have a fast version. "
                        "Falling back to the slow version."
                    )
            if tokenizer_class is None:
                tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

            if tokenizer_class is None:
                raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

            return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # Next, let's try to use the tokenizer_config file to get the tokenizer class.
        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
        if "_commit_hash" in tokenizer_config:
            kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
        tokenizer_auto_map = None
        if "auto_map" in tokenizer_config:
            if isinstance(tokenizer_config["auto_map"], (tuple, list)):
                # Legacy format for dynamic tokenizers
                tokenizer_auto_map = tokenizer_config["auto_map"]
            else:
                tokenizer_auto_map = tokenizer_config["auto_map"].get("AutoTokenizer", None)

        # If that did not work, let's try to use the config.
        if config_tokenizer_class is None:
            if not isinstance(config, PretrainedConfig):
                if gguf_file:
                    gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **kwargs)
                    config_dict = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"]
                    config = AutoConfig.for_model(**config_dict)
|
| 1108 |
+
else:
|
| 1109 |
+
config = AutoConfig.from_pretrained(
|
| 1110 |
+
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
| 1111 |
+
)
|
| 1112 |
+
config_tokenizer_class = config.tokenizer_class
|
| 1113 |
+
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
|
| 1114 |
+
tokenizer_auto_map = config.auto_map["AutoTokenizer"]
|
| 1115 |
+
|
| 1116 |
+
has_remote_code = tokenizer_auto_map is not None
|
| 1117 |
+
has_local_code = type(config) in TOKENIZER_MAPPING or (
|
| 1118 |
+
config_tokenizer_class is not None
|
| 1119 |
+
and (
|
| 1120 |
+
tokenizer_class_from_name(config_tokenizer_class) is not None
|
| 1121 |
+
or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None
|
| 1122 |
+
)
|
| 1123 |
+
)
|
| 1124 |
+
if has_remote_code:
|
| 1125 |
+
if use_fast and tokenizer_auto_map[1] is not None:
|
| 1126 |
+
class_ref = tokenizer_auto_map[1]
|
| 1127 |
+
else:
|
| 1128 |
+
class_ref = tokenizer_auto_map[0]
|
| 1129 |
+
if "--" in class_ref:
|
| 1130 |
+
upstream_repo = class_ref.split("--")[0]
|
| 1131 |
+
else:
|
| 1132 |
+
upstream_repo = None
|
| 1133 |
+
trust_remote_code = resolve_trust_remote_code(
|
| 1134 |
+
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
if has_remote_code and trust_remote_code:
|
| 1138 |
+
tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
|
| 1139 |
+
_ = kwargs.pop("code_revision", None)
|
| 1140 |
+
tokenizer_class.register_for_auto_class()
|
| 1141 |
+
return tokenizer_class.from_pretrained(
|
| 1142 |
+
pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs
|
| 1143 |
+
)
|
| 1144 |
+
elif config_tokenizer_class is not None:
|
| 1145 |
+
tokenizer_class = None
|
| 1146 |
+
if use_fast and not config_tokenizer_class.endswith("Fast"):
|
| 1147 |
+
tokenizer_class_candidate = f"{config_tokenizer_class}Fast"
|
| 1148 |
+
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
|
| 1149 |
+
if tokenizer_class is None:
|
| 1150 |
+
tokenizer_class_candidate = config_tokenizer_class
|
| 1151 |
+
tokenizer_class = tokenizer_class_from_name(tokenizer_class_candidate)
|
| 1152 |
+
if tokenizer_class is None:
|
| 1153 |
+
raise ValueError(
|
| 1154 |
+
f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
|
| 1155 |
+
)
|
| 1156 |
+
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
| 1157 |
+
|
| 1158 |
+
# Otherwise we have to be creative.
|
| 1159 |
+
# if model is an encoder decoder, the encoder tokenizer class is used by default
|
| 1160 |
+
if isinstance(config, EncoderDecoderConfig):
|
| 1161 |
+
if type(config.decoder) is not type(config.encoder):
|
| 1162 |
+
logger.warning(
|
| 1163 |
+
f"The encoder model config class: {config.encoder.__class__} is different from the decoder model "
|
| 1164 |
+
f"config class: {config.decoder.__class__}. It is not recommended to use the "
|
| 1165 |
+
"`AutoTokenizer.from_pretrained()` method in this case. Please use the encoder and decoder "
|
| 1166 |
+
"specific tokenizer classes."
|
| 1167 |
+
)
|
| 1168 |
+
config = config.encoder
|
| 1169 |
+
|
| 1170 |
+
model_type = config_class_to_model_type(type(config).__name__)
|
| 1171 |
+
if model_type is not None:
|
| 1172 |
+
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
|
| 1173 |
+
|
| 1174 |
+
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
|
| 1175 |
+
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
| 1176 |
+
else:
|
| 1177 |
+
if tokenizer_class_py is not None:
|
| 1178 |
+
return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
| 1179 |
+
else:
|
| 1180 |
+
raise ValueError(
|
| 1181 |
+
"This tokenizer cannot be instantiated. Please make sure you have `sentencepiece` installed "
|
| 1182 |
+
"in order to use this tokenizer."
|
| 1183 |
+
)
|
| 1184 |
+
|
| 1185 |
+
raise ValueError(
|
| 1186 |
+
f"Unrecognized configuration class {config.__class__} to build an AutoTokenizer.\n"
|
| 1187 |
+
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING)}."
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
@staticmethod
|
| 1191 |
+
def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
|
| 1192 |
+
"""
|
| 1193 |
+
Register a new tokenizer in this mapping.
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
Args:
|
| 1197 |
+
config_class ([`PretrainedConfig`]):
|
| 1198 |
+
The configuration corresponding to the model to register.
|
| 1199 |
+
slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
|
| 1200 |
+
The slow tokenizer to register.
|
| 1201 |
+
fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
|
| 1202 |
+
The fast tokenizer to register.
|
| 1203 |
+
"""
|
| 1204 |
+
if slow_tokenizer_class is None and fast_tokenizer_class is None:
|
| 1205 |
+
raise ValueError("You need to pass either a `slow_tokenizer_class` or a `fast_tokenizer_class")
|
| 1206 |
+
if slow_tokenizer_class is not None and issubclass(slow_tokenizer_class, PreTrainedTokenizerFast):
|
| 1207 |
+
raise ValueError("You passed a fast tokenizer in the `slow_tokenizer_class`.")
|
| 1208 |
+
if fast_tokenizer_class is not None and issubclass(fast_tokenizer_class, PreTrainedTokenizer):
|
| 1209 |
+
raise ValueError("You passed a slow tokenizer in the `fast_tokenizer_class`.")
|
| 1210 |
+
|
| 1211 |
+
if (
|
| 1212 |
+
slow_tokenizer_class is not None
|
| 1213 |
+
and fast_tokenizer_class is not None
|
| 1214 |
+
and issubclass(fast_tokenizer_class, PreTrainedTokenizerFast)
|
| 1215 |
+
and fast_tokenizer_class.slow_tokenizer_class != slow_tokenizer_class
|
| 1216 |
+
):
|
| 1217 |
+
raise ValueError(
|
| 1218 |
+
"The fast tokenizer class you are passing has a `slow_tokenizer_class` attribute that is not "
|
| 1219 |
+
"consistent with the slow tokenizer class you passed (fast tokenizer has "
|
| 1220 |
+
f"{fast_tokenizer_class.slow_tokenizer_class} and you passed {slow_tokenizer_class}. Fix one of those "
|
| 1221 |
+
"so they match!"
|
| 1222 |
+
)
|
| 1223 |
+
|
| 1224 |
+
# Avoid resetting a set slow/fast tokenizer if we are passing just the other ones.
|
| 1225 |
+
if config_class in TOKENIZER_MAPPING._extra_content:
|
| 1226 |
+
existing_slow, existing_fast = TOKENIZER_MAPPING[config_class]
|
| 1227 |
+
if slow_tokenizer_class is None:
|
| 1228 |
+
slow_tokenizer_class = existing_slow
|
| 1229 |
+
if fast_tokenizer_class is None:
|
| 1230 |
+
fast_tokenizer_class = existing_fast
|
| 1231 |
+
|
| 1232 |
+
TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
|
| 1233 |
+
|
| 1234 |
+
|
| 1235 |
+
__all__ = ["TOKENIZER_MAPPING", "AutoTokenizer"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/auto/video_processing_auto.py
ADDED
@@ -0,0 +1,393 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoVideoProcessor class."""

import importlib
import json
import os
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional, Union

# Build the list of all video processors
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import CONFIG_NAME, VIDEO_PROCESSOR_NAME, cached_file, is_torchvision_available, logging
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
    CONFIG_MAPPING_NAMES,
    AutoConfig,
    model_type_to_module_name,
    replace_list_option_in_docstrings,
)


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
    # This significantly improves completion suggestion performance when
    # the transformers package is used with Microsoft's Pylance language server.
    VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
    VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
        [
            ("glm4v", "Glm4vVideoProcessor"),
            ("instructblip", "InstructBlipVideoVideoProcessor"),
            ("instructblipvideo", "InstructBlipVideoVideoProcessor"),
            ("internvl", "InternVLVideoProcessor"),
            ("llava_next_video", "LlavaNextVideoVideoProcessor"),
            ("llava_onevision", "LlavaOnevisionVideoProcessor"),
            ("perception_lm", "PerceptionLMVideoProcessor"),
            ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
            ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
            ("qwen2_vl", "Qwen2VLVideoProcessor"),
            ("qwen3_omni_moe", "Qwen2VLVideoProcessor"),
            ("qwen3_vl", "Qwen3VLVideoProcessor"),
            ("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
            ("sam2_video", "Sam2VideoVideoProcessor"),
            ("smolvlm", "SmolVLMVideoProcessor"),
            ("video_llava", "VideoLlavaVideoProcessor"),
            ("vjepa2", "VJEPA2VideoProcessor"),
        ]
    )

for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
    fast_video_processor_class = video_processors

    # If torchvision is not available, we set it to None
    if not is_torchvision_available():
        fast_video_processor_class = None

    VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class

VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)


def video_processor_class_from_name(class_name: str):
    for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
        if class_name in extractors:
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values():
        if getattr(extractor, "__name__", None) == class_name:
            return extractor

    # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
    # init and we return the proper dummy to get an appropriate error message.
    main_module = importlib.import_module("transformers")
    if hasattr(main_module, class_name):
        return getattr(main_module, class_name)

    return None


def get_video_processor_config(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Loads the video processor configuration from a pretrained model video processor configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the configuration files and override the cached versions if
            they exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the video processor configuration from local files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Dict`: The configuration of the video processor.

    Examples:

    ```python
    # Download configuration from huggingface.co and cache.
    video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
    # This model does not have a video processor config so the result will be an empty dict.
    video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")

    # Save a pretrained video processor locally and you can reload its config
    from transformers import AutoVideoProcessor

    video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
    video_processor.save_pretrained("video-processor-test")
    video_processor_config = get_video_processor_config("video-processor-test")
    ```"""
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    resolved_config_file = cached_file(
        pretrained_model_name_or_path,
        VIDEO_PROCESSOR_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
    )
    if resolved_config_file is None:
        logger.info(
            "Could not locate the video processor configuration file, will try to use the model config instead."
        )
        return {}

    with open(resolved_config_file, encoding="utf-8") as reader:
        return json.load(reader)


@requires(backends=("vision", "torchvision"))
class AutoVideoProcessor:
    r"""
    This is a generic video processor class that will be instantiated as one of the video processor classes of the
    library when created with the [`AutoVideoProcessor.from_pretrained`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise OSError(
            "AutoVideoProcessor is designed to be instantiated "
            "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
        )

    @classmethod
    @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r"""
        Instantiate one of the video processor classes of the library from a pretrained model vocabulary.

        The video processor class to instantiate is selected based on the `model_type` property of the config object
        (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
        missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Params:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained video_processor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a video processor file saved using the
                  [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
                  `./my_model_directory/`.
                - a path or url to a saved video processor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model video processor should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the video processor files and override the cached
                versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `hf auth login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final video processor object. If `True`, then this
                function returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
                consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
                `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it will
                execute code present on the Hub on your local machine.
            kwargs (`dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are video processor attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
                controlled by the `return_unused_kwargs` keyword parameter.

        <Tip>

        Passing `token=True` is required when you want to use a private model.

        </Tip>

        Examples:

        ```python
        >>> from transformers import AutoVideoProcessor

        >>> # Download video processor from huggingface.co and cache.
        >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

        >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
        >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
        ```"""
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token") is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True

        config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
        video_processor_class = config_dict.get("video_processor_type", None)
        video_processor_auto_map = None
        if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
            video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]

        # If we still don't have the video processor class, check if we're loading from a previous image processor config
        # and if so, infer the video processor class from there.
        if video_processor_class is None and video_processor_auto_map is None:
            image_processor_class = config_dict.pop("image_processor_type", None)
            if image_processor_class is not None:
                video_processor_class_inferred = image_processor_class.replace("ImageProcessor", "VideoProcessor")

                # Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
                # We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
                if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
                    video_processor_class = video_processor_class_inferred
            if "AutoImageProcessor" in config_dict.get("auto_map", {}):
                image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]
                video_processor_auto_map = image_processor_auto_map.replace("ImageProcessor", "VideoProcessor")

        # If we don't find the video processor class in the video processor config, let's try the model config.
        if video_processor_class is None and video_processor_auto_map is None:
            if not isinstance(config, PretrainedConfig):
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
                )
            # It could be in `config.video_processor_type`
            video_processor_class = getattr(config, "video_processor_type", None)
            if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
                video_processor_auto_map = config.auto_map["AutoVideoProcessor"]

        if video_processor_class is not None:
            video_processor_class = video_processor_class_from_name(video_processor_class)

        has_remote_code = video_processor_auto_map is not None
        has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = video_processor_auto_map
            video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
            _ = kwargs.pop("code_revision", None)
            video_processor_class.register_for_auto_class()
            return video_processor_class.from_dict(config_dict, **kwargs)
        elif video_processor_class is not None:
            return video_processor_class.from_dict(config_dict, **kwargs)
        # Last try: we use the VIDEO_PROCESSOR_MAPPING.
        elif type(config) in VIDEO_PROCESSOR_MAPPING:
            video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]

            if video_processor_class is not None:
                return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
            else:
                raise ValueError(
                    "This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
                )

        raise ValueError(
            f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
            f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES)}"
        )

    @staticmethod
    def register(
        config_class,
        video_processor_class,
        exist_ok=False,
    ):
        """
        Register a new video processor for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            video_processor_class ([`BaseVideoProcessor`]):
                The video processor to register.
        """
        VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)


__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_bark import *
    from .modeling_bark import *
    from .processing_bark import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
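
The `_LazyModule` swap above defers importing `modeling_bark` and friends until an attribute is first accessed, keeping `import transformers` cheap. The general mechanism can be sketched with PEP 562 module-level `__getattr__`; this is a simplified stand-in as it would appear in a package `__init__.py`, not the actual `_LazyModule` implementation:

```python
# Simplified sketch of lazy submodule loading via PEP 562 module __getattr__.
# This mirrors the idea behind _LazyModule but is not its implementation.
import importlib

_SUBMODULES = {"configuration_bark", "modeling_bark", "processing_bark"}


def __getattr__(name):
    if name in _SUBMODULES:
        # Import the submodule only on first attribute access.
        return importlib.import_module(f".{name}", __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```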
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/configuration_bark.py
ADDED
@@ -0,0 +1,303 @@
# coding=utf-8
# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BARK model configuration"""

from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings, logging
from ..auto import CONFIG_MAPPING, AutoConfig


logger = logging.get_logger(__name__)


BARK_SUBMODELCONFIG_START_DOCSTRING = """
    This is the configuration class to store the configuration of a [`{model}`]. It is used to instantiate the model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Bark [suno/bark](https://huggingface.co/suno/bark)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        block_size (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        input_vocab_size (`int`, *optional*, defaults to 10_048):
            Vocabulary size of a Bark sub-model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`{model}`]. Defaults to 10_048 but should be chosen carefully with
            regard to the sub-model used.
        output_vocab_size (`int`, *optional*, defaults to 10_048):
            Output vocabulary size of a Bark sub-model. Defines the number of different tokens that can be represented
            by the `output_ids` when passing forward a [`{model}`]. Defaults to 10_048 but should be chosen carefully
            with regard to the sub-model used.
        num_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the given sub-model.
        num_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer architecture.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the architecture.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        bias (`bool`, *optional*, defaults to `True`):
            Whether or not to use bias in the linear layers and layer norm layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
"""


class BarkSubModelConfig(PretrainedConfig):
    keys_to_ignore_at_inference = ["past_key_values"]

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
        "vocab_size": "input_vocab_size",
        "window_size": "block_size",
    }

    def __init__(
        self,
        block_size=1024,
        input_vocab_size=10_048,
        output_vocab_size=10_048,
        num_layers=12,
        num_heads=12,
        hidden_size=768,
        dropout=0.0,
        bias=True,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
        initializer_range=0.02,
        use_cache=True,
        **kwargs,
    ):
        self.block_size = block_size
        self.input_vocab_size = input_vocab_size
        self.output_vocab_size = output_vocab_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.bias = bias
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        super().__init__(**kwargs)


@add_start_docstrings(
    BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkSemanticConfig", model="BarkSemanticModel"),
    """
    Example:

    ```python
    >>> from transformers import BarkSemanticConfig, BarkSemanticModel

    >>> # Initializing a Bark sub-module style configuration
    >>> configuration = BarkSemanticConfig()

    >>> # Initializing a model (with random weights) from the suno/bark style configuration
    >>> model = BarkSemanticModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```""",
)
class BarkSemanticConfig(BarkSubModelConfig):
    model_type = "semantic"
    base_config_key = "semantic_config"


@add_start_docstrings(
    BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkCoarseConfig", model="BarkCoarseModel"),
    """
    Example:

    ```python
    >>> from transformers import BarkCoarseConfig, BarkCoarseModel

    >>> # Initializing a Bark sub-module style configuration
    >>> configuration = BarkCoarseConfig()

    >>> # Initializing a model (with random weights) from the suno/bark style configuration
    >>> model = BarkCoarseModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```""",
)
class BarkCoarseConfig(BarkSubModelConfig):
    model_type = "coarse_acoustics"
    base_config_key = "coarse_acoustics_config"


@add_start_docstrings(
    BARK_SUBMODELCONFIG_START_DOCSTRING.format(config="BarkFineConfig", model="BarkFineModel"),
    """
    n_codes_total (`int`, *optional*, defaults to 8):
        The total number of audio codebooks predicted. Used in the fine acoustics sub-model.
    n_codes_given (`int`, *optional*, defaults to 1):
        The number of audio codebooks predicted in the coarse acoustics sub-model. Used in the acoustics
        sub-models.
    Example:

    ```python
    >>> from transformers import BarkFineConfig, BarkFineModel

    >>> # Initializing a Bark sub-module style configuration
    >>> configuration = BarkFineConfig()

    >>> # Initializing a model (with random weights) from the suno/bark style configuration
    >>> model = BarkFineModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```""",
)
class BarkFineConfig(BarkSubModelConfig):
    model_type = "fine_acoustics"
    base_config_key = "fine_acoustics_config"

    def __init__(self, tie_word_embeddings=True, n_codes_total=8, n_codes_given=1, **kwargs):
        self.n_codes_total = n_codes_total
        self.n_codes_given = n_codes_given

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


class BarkConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`BarkModel`]. It is used to instantiate a Bark
    model according to the specified sub-models configurations, defining the model architecture.

    Instantiating a configuration with the defaults will yield a similar configuration to that of the Bark
    [suno/bark](https://huggingface.co/suno/bark) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        semantic_config ([`BarkSemanticConfig`], *optional*):
            Configuration of the underlying semantic sub-model.
        coarse_acoustics_config ([`BarkCoarseConfig`], *optional*):
            Configuration of the underlying coarse acoustics sub-model.
        fine_acoustics_config ([`BarkFineConfig`], *optional*):
            Configuration of the underlying fine acoustics sub-model.
        codec_config ([`AutoConfig`], *optional*):
            Configuration of the underlying codec sub-model.

    Example:

    ```python
    >>> from transformers import (
    ...     BarkSemanticConfig,
    ...     BarkCoarseConfig,
    ...     BarkFineConfig,
    ...     BarkModel,
    ...     BarkConfig,
    ...     AutoConfig,
    ... )

    >>> # Initializing Bark sub-modules configurations.
    >>> semantic_config = BarkSemanticConfig()
    >>> coarse_acoustics_config = BarkCoarseConfig()
    >>> fine_acoustics_config = BarkFineConfig()
    >>> codec_config = AutoConfig.from_pretrained("facebook/encodec_24khz")

    >>> # Initializing a Bark module style configuration
    >>> configuration = BarkConfig.from_sub_model_configs(
    ...     semantic_config, coarse_acoustics_config, fine_acoustics_config, codec_config
    ... )

    >>> # Initializing a model (with random weights)
    >>> model = BarkModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "bark"
    sub_configs = {
        "semantic_config": BarkSemanticConfig,
        "coarse_acoustics_config": BarkCoarseConfig,
        "fine_acoustics_config": BarkFineConfig,
        "codec_config": AutoConfig,
    }

    def __init__(
        self,
        semantic_config: Optional[dict] = None,
        coarse_acoustics_config: Optional[dict] = None,
        fine_acoustics_config: Optional[dict] = None,
        codec_config: Optional[dict] = None,
        initializer_range=0.02,
        **kwargs,
    ):
        if semantic_config is None:
            semantic_config = {}
            logger.info("semantic_config is None. initializing the semantic model with default values.")

        if coarse_acoustics_config is None:
            coarse_acoustics_config = {}
            logger.info("coarse_acoustics_config is None. initializing the coarse model with default values.")

        if fine_acoustics_config is None:
            fine_acoustics_config = {}
            logger.info("fine_acoustics_config is None. initializing the fine model with default values.")

        if codec_config is None:
            codec_config = {}
            logger.info("codec_config is None. initializing the codec model with default values.")

        self.semantic_config = BarkSemanticConfig(**semantic_config)
        self.coarse_acoustics_config = BarkCoarseConfig(**coarse_acoustics_config)
        self.fine_acoustics_config = BarkFineConfig(**fine_acoustics_config)
        codec_model_type = codec_config.get("model_type", "encodec")
        self.codec_config = CONFIG_MAPPING[codec_model_type](**codec_config)

        self.initializer_range = initializer_range

        super().__init__(**kwargs)

    @classmethod
    def from_sub_model_configs(
        cls,
        semantic_config: BarkSemanticConfig,
        coarse_acoustics_config: BarkCoarseConfig,
        fine_acoustics_config: BarkFineConfig,
        codec_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`BarkConfig`] (or a derived class) from bark sub-models configuration.

        Returns:
            [`BarkConfig`]: An instance of a configuration object
        """
        return cls(
            semantic_config=semantic_config.to_dict(),
            coarse_acoustics_config=coarse_acoustics_config.to_dict(),
            fine_acoustics_config=fine_acoustics_config.to_dict(),
            codec_config=codec_config.to_dict(),
            **kwargs,
        )


__all__ = ["BarkCoarseConfig", "BarkConfig", "BarkFineConfig", "BarkSemanticConfig"]
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/generation_configuration_bark.py
ADDED
@@ -0,0 +1,330 @@
+# coding=utf-8
+# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BARK model generation configuration"""
+
+import copy
+from typing import Optional
+
+from ...generation.configuration_utils import GenerationConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class BarkSemanticGenerationConfig(GenerationConfig):
+    model_type = "semantic"
+
+    def __init__(
+        self,
+        eos_token_id=10_000,
+        renormalize_logits=True,
+        max_new_tokens=768,
+        output_scores=False,
+        return_dict_in_generate=False,
+        output_hidden_states=False,
+        output_attentions=False,
+        temperature=1.0,
+        do_sample=False,
+        text_encoding_offset=10_048,
+        text_pad_token=129_595,
+        semantic_infer_token=129_599,
+        semantic_vocab_size=10_000,
+        max_input_semantic_length=256,
+        semantic_rate_hz=49.9,
+        min_eos_p=None,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkSemanticModel`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read the
+        documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            eos_token_id (`int`, *optional*, defaults to 10_000):
+                The id of the *end-of-sequence* token.
+            renormalize_logits (`bool`, *optional*, defaults to `True`):
+                Whether to renormalize the logits after applying all the logits processors (including the
+                custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
+                score logits are normalized but some logit processors break the normalization.
+            max_new_tokens (`int`, *optional*, defaults to 768):
+                The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
+            output_scores (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            temperature (`float`, *optional*, defaults to 1.0):
+                The value used to modulate the next token probabilities.
+            do_sample (`bool`, *optional*, defaults to `False`):
+                Whether or not to use sampling; use greedy decoding otherwise.
+            text_encoding_offset (`int`, *optional*, defaults to 10_048):
+                Text encoding offset.
+            text_pad_token (`int`, *optional*, defaults to 129_595):
+                Text pad token.
+            semantic_infer_token (`int`, *optional*, defaults to 129_599):
+                Semantic infer token.
+            semantic_vocab_size (`int`, *optional*, defaults to 10_000):
+                Semantic vocab size.
+            max_input_semantic_length (`int`, *optional*, defaults to 256):
+                Max length of semantic input vector.
+            semantic_rate_hz (`float`, *optional*, defaults to 49.9):
+                Semantic rate in Hertz.
+            min_eos_p (`float`, *optional*):
+                Minimum threshold of the probability of the EOS token for it to be sampled. This is an early stopping
+                strategy to mitigate potential unwanted generations at the end of a prompt. The original implementation
+                suggests a default value of 0.2.
+        """
+        super().__init__(
+            temperature=temperature,
+            do_sample=do_sample,
+            eos_token_id=eos_token_id,
+            renormalize_logits=renormalize_logits,
+            max_new_tokens=max_new_tokens,
+            output_scores=output_scores,
+            return_dict_in_generate=return_dict_in_generate,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        self.text_encoding_offset = text_encoding_offset
+        self.text_pad_token = text_pad_token
+        self.semantic_pad_token = eos_token_id
+        self.semantic_infer_token = semantic_infer_token
+        self.semantic_vocab_size = semantic_vocab_size
+        self.max_input_semantic_length = max_input_semantic_length
+        self.semantic_rate_hz = semantic_rate_hz
+        self.min_eos_p = min_eos_p
+
+
+class BarkCoarseGenerationConfig(GenerationConfig):
+    model_type = "coarse_acoustics"
+
+    def __init__(
+        self,
+        renormalize_logits=True,
+        output_scores=False,
+        return_dict_in_generate=False,
+        output_hidden_states=False,
+        output_attentions=False,
+        temperature=1.0,
+        do_sample=False,
+        coarse_semantic_pad_token=12_048,
+        coarse_rate_hz=75,
+        n_coarse_codebooks=2,
+        coarse_infer_token=12_050,
+        max_coarse_input_length=256,
+        max_coarse_history: int = 630,
+        sliding_window_len: int = 60,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkCoarseModel`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read the
+        documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            renormalize_logits (`bool`, *optional*, defaults to `True`):
+                Whether to renormalize the logits after applying all the logits processors (including the
+                custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
+                score logits are normalized but some logit processors break the normalization.
+            output_scores (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            temperature (`float`, *optional*, defaults to 1.0):
+                The value used to modulate the next token probabilities.
+            do_sample (`bool`, *optional*, defaults to `False`):
+                Whether or not to use sampling; use greedy decoding otherwise.
+            coarse_semantic_pad_token (`int`, *optional*, defaults to 12_048):
+                Coarse semantic pad token.
+            coarse_rate_hz (`int`, *optional*, defaults to 75):
+                Coarse rate in Hertz.
+            n_coarse_codebooks (`int`, *optional*, defaults to 2):
+                Number of coarse codebooks.
+            coarse_infer_token (`int`, *optional*, defaults to 12_050):
+                Coarse infer token.
+            max_coarse_input_length (`int`, *optional*, defaults to 256):
+                Max length of input coarse vector.
+            max_coarse_history (`int`, *optional*, defaults to 630):
+                Max length of the output of the coarse acoustics model used in the fine generation step.
+            sliding_window_len (`int`, *optional*, defaults to 60):
+                The coarse generation step uses a sliding window to generate raw audio.
+        """
+        super().__init__(
+            temperature=temperature,
+            do_sample=do_sample,
+            renormalize_logits=renormalize_logits,
+            output_scores=output_scores,
+            return_dict_in_generate=return_dict_in_generate,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        self.coarse_semantic_pad_token = coarse_semantic_pad_token
+        self.coarse_rate_hz = coarse_rate_hz
+        self.n_coarse_codebooks = n_coarse_codebooks
+        self.coarse_infer_token = coarse_infer_token
+        self.max_coarse_input_length = max_coarse_input_length
+        self.max_coarse_history = max_coarse_history
+        self.sliding_window_len = sliding_window_len
+
+
+class BarkFineGenerationConfig(GenerationConfig):
+    model_type = "fine_acoustics"
+
+    def __init__(
+        self,
+        temperature=1.0,
+        max_fine_history_length=512,
+        max_fine_input_length=1024,
+        n_fine_codebooks=8,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkFineModel`].
+
+        [`BarkFineModel`] is an autoencoder model, so should not usually be used for generation. However, under the
+        hood, it uses `temperature` when used by [`BarkModel`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read the
+        documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            temperature (`float`, *optional*):
+                The value used to modulate the next token probabilities.
+            max_fine_history_length (`int`, *optional*, defaults to 512):
+                Max length of the fine history vector.
+            max_fine_input_length (`int`, *optional*, defaults to 1024):
+                Max length of fine input vector.
+            n_fine_codebooks (`int`, *optional*, defaults to 8):
+                Number of codebooks used.
+        """
+        super().__init__(temperature=temperature)
+
+        self.max_fine_history_length = max_fine_history_length
+        self.max_fine_input_length = max_fine_input_length
+        self.n_fine_codebooks = n_fine_codebooks
+
+    def validate(self, **kwargs):
+        """
+        Overrides GenerationConfig.validate because BarkFineGenerationConfig doesn't use any parameters other than
+        `temperature`.
+        """
+        pass
+
+
+class BarkGenerationConfig(GenerationConfig):
+    model_type = "bark"
+
+    # TODO (joao): nested from_dict
+
+    def __init__(
+        self,
+        semantic_config: Optional[dict] = None,
+        coarse_acoustics_config: Optional[dict] = None,
+        fine_acoustics_config: Optional[dict] = None,
+        sample_rate=24_000,
+        codebook_size=1024,
+        **kwargs,
+    ):
+        """Class that holds a generation configuration for [`BarkModel`].
+
+        The [`BarkModel`] does not have a `generate` method, but uses this class to generate speech with a nested
+        [`BarkGenerationConfig`] which uses [`BarkSemanticGenerationConfig`], [`BarkCoarseGenerationConfig`],
+        [`BarkFineGenerationConfig`].
+
+        This configuration inherits from [`GenerationConfig`] and can be used to control the model generation. Read the
+        documentation from [`GenerationConfig`] for more information.
+
+        Args:
+            semantic_config (`Dict`, *optional*):
+                Semantic generation configuration.
+            coarse_acoustics_config (`Dict`, *optional*):
+                Coarse generation configuration.
+            fine_acoustics_config (`Dict`, *optional*):
+                Fine generation configuration.
+            sample_rate (`int`, *optional*, defaults to 24_000):
+                Sample rate.
+            codebook_size (`int`, *optional*, defaults to 1024):
+                Vector length for each codebook.
+        """
+        if semantic_config is None:
+            semantic_config = {}
+            logger.info("semantic_config is None. Initializing the semantic model with default values.")
+
+        if coarse_acoustics_config is None:
+            coarse_acoustics_config = {}
+            logger.info("coarse_acoustics_config is None. Initializing the coarse model with default values.")
+
+        if fine_acoustics_config is None:
+            fine_acoustics_config = {}
+            logger.info("fine_acoustics_config is None. Initializing the fine model with default values.")
+
+        self.semantic_config = BarkSemanticGenerationConfig(**semantic_config)
+        self.coarse_acoustics_config = BarkCoarseGenerationConfig(**coarse_acoustics_config)
+        self.fine_acoustics_config = BarkFineGenerationConfig(**fine_acoustics_config)
+
+        self.sample_rate = sample_rate
+        self.codebook_size = codebook_size
+
+    @classmethod
+    def from_sub_model_configs(
+        cls,
+        semantic_config: BarkSemanticGenerationConfig,
+        coarse_acoustics_config: BarkCoarseGenerationConfig,
+        fine_acoustics_config: BarkFineGenerationConfig,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`BarkGenerationConfig`] (or a derived class) from Bark sub-model generation configurations.
+
+        Returns:
+            [`BarkGenerationConfig`]: An instance of a configuration object
+        """
+        return cls(
+            semantic_config=semantic_config.to_dict(),
+            coarse_acoustics_config=coarse_acoustics_config.to_dict(),
+            fine_acoustics_config=fine_acoustics_config.to_dict(),
+            **kwargs,
+        )
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+
+        output["semantic_config"] = self.semantic_config.to_dict()
+        output["coarse_acoustics_config"] = self.coarse_acoustics_config.to_dict()
+        output["fine_acoustics_config"] = self.fine_acoustics_config.to_dict()
+
+        output["model_type"] = self.__class__.model_type
+        return output
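For orientation, a minimal usage sketch of the nested generation configuration defined above (not part of the diff; the import path follows the file location in this commit, and `min_eos_p=0.2` is just the value the docstring suggests):

    from transformers.models.bark.generation_configuration_bark import (
        BarkCoarseGenerationConfig,
        BarkFineGenerationConfig,
        BarkGenerationConfig,
        BarkSemanticGenerationConfig,
    )

    # Build the three sub-model configs, then compose them into the nested config.
    semantic = BarkSemanticGenerationConfig(min_eos_p=0.2)
    coarse = BarkCoarseGenerationConfig()
    fine = BarkFineGenerationConfig()
    nested = BarkGenerationConfig.from_sub_model_configs(semantic, coarse, fine)

    # to_dict() re-serializes the three sub-configs, so the nested config round-trips.
    as_dict = nested.to_dict()
    print(as_dict["semantic_config"]["min_eos_p"])  # -> 0.2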
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/modeling_bark.py
ADDED
@@ -0,0 +1,1628 @@
+# coding=utf-8
+# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BARK model."""
+
+import math
+import warnings
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...generation.logits_process import (
+    AlternatingCodebooksLogitsProcessor,
+    BarkEosPrioritizerLogitsProcessor,
+    SuppressTokensLogitsProcessor,
+)
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
+from ...modeling_utils import PreTrainedModel, get_parameter_device
+from ...utils import (
+    auto_docstring,
+    is_accelerate_available,
+    is_torch_accelerator_available,
+    logging,
+)
+from ..auto import AutoModel
+from .configuration_bark import (
+    BarkCoarseConfig,
+    BarkConfig,
+    BarkFineConfig,
+    BarkSemanticConfig,
+    BarkSubModelConfig,
+)
+from .generation_configuration_bark import (
+    BarkCoarseGenerationConfig,
+    BarkFineGenerationConfig,
+    BarkSemanticGenerationConfig,
+)
+
+
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+
+
+class BarkSelfAttention(nn.Module):
+    # adapted from GPTNeoSelfAttention and Bark code
+    # BarkSelfAttention can have two attention types, i.e. full attention or causal attention
+
+    def __init__(self, config, is_causal=False, layer_idx=None):
+        super().__init__()
+
+        # regularization
+        self.dropout = config.dropout
+        self.attn_dropout = nn.Dropout(config.dropout)
+        self.resid_dropout = nn.Dropout(config.dropout)
+
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = self.embed_dim // self.num_heads
+
+        if config.hidden_size % config.num_heads != 0:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        # key, query, value projections for all heads, but in a batch
+        self.att_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
+        # output projection
+        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias)
+
+        self.is_causal = is_causal
+        self.layer_idx = layer_idx
+        if is_causal:
+            block_size = config.block_size
+            bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
+            self.register_buffer("bias", bias)
+
+    # Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+
+        # re-assemble all head outputs side by side
+        # (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+        tensor = tensor.transpose(1, 2).contiguous()
+        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+
+        return tensor
+
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        # unlike GPTNeo's SelfAttention, divide by the square root of the dimension of the query and the key
+        attn_weights = torch.matmul(query, key.transpose(-1, -2)) * (1.0 / math.sqrt(self.head_dim))
+
+        if self.is_causal:
+            query_length, key_length = query.size(-2), key.size(-2)
+
+            # fill the upper left part of the attention weights with -inf
+            attn_weights = attn_weights.masked_fill(
+                self.bias[:, :, key_length - query_length : key_length, :key_length] == 0,
+                torch.finfo(attn_weights.dtype).min,
+            )
+
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights.to(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        # (batch, num_heads, seq_len, seq_len) x (batch, num_heads, seq_len, attn_head_size)
+        # -> (batch, num_heads, seq_len, attn_head_size)
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        past_key_values=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+        cache_position=None,
+    ):
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if past_key_values is not None:
+            key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
+
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        return attn_output, attn_weights
+
+
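A quick shape sanity check (not part of the diff) for the `_split_heads`/`_merge_heads` pair above, using illustrative sizes:

    import torch

    batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
    hidden = torch.randn(batch, seq_len, num_heads * head_dim)

    # _split_heads: (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    split = hidden.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

    # _merge_heads inverts it: transpose back and re-flatten the head dims
    merged = split.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * head_dim)
    assert torch.equal(hidden, merged)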
+class BarkSelfFlashAttention2(BarkSelfAttention):
+    """
+    Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
+        return tensor
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        # re-assemble all head outputs side by side
+        # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+        return tensor
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        past_key_values=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+        cache_position=None,
+    ):
+        batch_size, query_len, _ = hidden_states.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if past_key_values is not None:
+            key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})
+
+        attn_output = _flash_attention_forward(
+            query,
+            key,
+            value,
+            attention_mask,
+            query_len,
+            dropout=self.dropout if self.training else 0.0,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        return attn_output, None
+
+
+BARK_ATTENTION_CLASSES = {
+    "eager": BarkSelfAttention,
+    "flash_attention_2": BarkSelfFlashAttention2,
+}
+
+
+class BarkMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.in_proj = nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=config.bias)
+        self.out_proj = nn.Linear(4 * config.hidden_size, config.hidden_size, bias=config.bias)
+        self.dropout = nn.Dropout(config.dropout)
+        self.gelu = nn.GELU()
+
+    def forward(self, hidden_states):
+        hidden_states = self.in_proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class BarkBlock(GradientCheckpointingLayer):
+    def __init__(self, config, is_causal=False, layer_idx=None):
+        super().__init__()
+
+        if is_causal:
+            # if causal, the layerNorm bias is optional to stick with Bark's choice of leaving an optional bias
+            # in autoregressive models (corresponding to the "Text" and the "Coarse" modules)
+            self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias)
+            self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias)
+        else:
+            self.layernorm_1 = nn.LayerNorm(config.hidden_size)
+            self.layernorm_2 = nn.LayerNorm(config.hidden_size)
+
+        self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](
+            config, is_causal=is_causal, layer_idx=layer_idx
+        )
+
+        self.mlp = BarkMLP(config)
+
+    def forward(
+        self,
+        hidden_states,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+        cache_position=None,
+    ):
+        intermediary_hidden_states = self.layernorm_1(hidden_states)
+
+        attn_outputs = self.attn(
+            intermediary_hidden_states,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+
+        attn_output = attn_outputs[0]  # output_attn: output, present_key_values, (attn_weights)
+        outputs = attn_outputs[1:]
+
+        intermediary_hidden_states = hidden_states + attn_output
+        intermediary_hidden_states = intermediary_hidden_states + self.mlp(
+            self.layernorm_2(intermediary_hidden_states)
+        )
+
+        return (intermediary_hidden_states,) + outputs
+
+
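`BarkBlock` above follows the pre-LayerNorm residual pattern. A minimal sketch (not part of the diff; the identity `attn` and the `nn.Linear` are stand-ins for the real attention and `BarkMLP` sub-modules):

    import torch
    from torch import nn

    hidden_size = 16
    ln1, ln2 = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
    attn = lambda x: x                          # stand-in for self-attention
    mlp = nn.Linear(hidden_size, hidden_size)   # stand-in for BarkMLP

    x = torch.randn(2, 5, hidden_size)
    x = x + attn(ln1(x))  # normalize first, then add the residual
    x = x + mlp(ln2(x))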
+@auto_docstring
+class BarkPreTrainedModel(PreTrainedModel):
+    config: BarkConfig
+    supports_gradient_checkpointing = False
+    _supports_flash_attn = True
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear,)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+
+        # if has _hf_hook, has been offloaded so the device has to be found in the hook
+        if not hasattr(self, "_hf_hook"):
+            return get_parameter_device(self)
+        for module in self.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+
+        return get_parameter_device(self)
+
+
+# GPT2-like autoregressive model
+class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
+    config: BarkSubModelConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        # initialize as an autoregressive GPT-like model
+        self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size)
+        self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
+
+        self.drop = nn.Dropout(config.dropout)
+
+        self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)])
+
+        self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
+        # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
+        return None
+
+    def get_input_embeddings(self):
+        return self.input_embeds_layer
+
+    def set_input_embeddings(self, new_embeddings):
+        self.input_embeds_layer = new_embeddings
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        attention_mask=None,
+        input_embeds=None,
+        past_key_values=None,
+        position_ids=None,
+        use_cache=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        # Overwritten -- bark uses `input_embeds`, not `inputs_embeds`
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        model_inputs["input_embeds"] = model_inputs.pop("inputs_embeds", None)
+        return model_inputs
+
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        input_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
+        r"""
+        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            Here, due to `Bark` particularities, if `past_key_values` is used, `input_embeds` will be ignored and you
+            have to use `input_ids`. If `past_key_values` is not used and `use_cache` is set to `True`, `input_embeds`
+            is used in priority instead of `input_ids`.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        loss = None
+        if labels is not None:
+            raise NotImplementedError(
+                "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model."
+            )
+
+        # Verify if input_embeds already exists,
+        # then compute embeddings.
+        if input_ids is not None and input_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
+        elif input_embeds is not None and past_key_values is None:
+            # we want to return the input_embeds in priority so that it is in line with a weird hack
+            # of Bark which concatenates two bits of the input_embeds on the first forward pass of the semantic model
+            pass
+        elif input_ids is not None:
+            input_embeds = self.input_embeds_layer(input_ids)  # token embeddings of shape (b, t, n_embd)
+        elif input_embeds is not None:
+            pass
+        else:
+            raise ValueError("You have to specify either input_ids or input_embeds")
+
+        input_shape = input_embeds.size()[:-1]
+        batch_size = input_embeds.shape[0]
+        seq_length = input_shape[-1]
+
+        device = input_ids.device if input_ids is not None else input_embeds.device
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+        if use_cache and isinstance(past_key_values, tuple):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `DynamicCache` instead, e.g. "
+                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+            )
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0)  # shape (1, seq_length)
+
+        position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)
+
+        # Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                attention_mask = attention_mask.view(batch_size, -1)
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x num_heads x N x N
+        # head_mask has shape num_layers x batch x num_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+
+        hidden_states = self.drop(input_embeds + position_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        for i, block in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            outputs = block(
+                hidden_states,
+                past_key_values=past_key_values,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+
+            hidden_states = outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[1],)
+
+        hidden_states = self.layernorm_final(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            return tuple(
+                v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None
+            )
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
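A small sketch (not part of the diff) of the cached-decoding position logic in `BarkCausalModel.forward` above: when the cache already holds `past_length` tokens, positions for the new tokens simply continue from there.

    import torch

    past_length, seq_length = 3, 2  # illustrative: 3 cached tokens, 2 new ones
    position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long)
    print(position_ids.unsqueeze(0))  # tensor([[3, 4]])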
+
@auto_docstring(
|
| 586 |
+
custom_intro="""
|
| 587 |
+
Bark semantic (or text) model. It shares the same architecture as the coarse model.
|
| 588 |
+
It is a GPT-2 like autoregressive model with a language modeling head on top.
|
| 589 |
+
"""
|
| 590 |
+
)
|
| 591 |
+
class BarkSemanticModel(BarkCausalModel):
|
| 592 |
+
base_model_prefix = "semantic"
|
| 593 |
+
config: BarkSemanticConfig
|
| 594 |
+
|
| 595 |
+
def generate(
|
| 596 |
+
self,
|
| 597 |
+
input_ids: torch.Tensor,
|
| 598 |
+
semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
|
| 599 |
+
history_prompt: Optional[dict[str, torch.Tensor]] = None,
|
| 600 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 601 |
+
**kwargs,
|
| 602 |
+
) -> torch.LongTensor:
|
| 603 |
+
"""
|
| 604 |
+
Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt.
|
| 605 |
+
|
| 606 |
+
Args:
|
| 607 |
+
input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
|
| 608 |
+
Input ids, i.e tokenized input sentences. Will be truncated up to
|
| 609 |
+
semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as
|
| 610 |
+
long as the longest generation among the batch.
|
| 611 |
+
semantic_generation_config (`BarkSemanticGenerationConfig`):
|
| 612 |
+
Generation config indicating how to generate the semantic tokens.
|
| 613 |
+
history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
|
| 614 |
+
Optional `Bark` speaker prompt.
|
| 615 |
+
attention_mask (`Optional[torch.Tensor]`, *optional*):
|
| 616 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 617 |
+
|
| 618 |
+
- 1 for tokens that are **not masked**,
|
| 619 |
+
- 0 for tokens that are **masked**.
|
| 620 |
+
|
| 621 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 622 |
+
Returns:
|
| 623 |
+
torch.LongTensor: Output semantic tokens.
|
| 624 |
+
"""
|
| 625 |
+
if semantic_generation_config is None:
|
| 626 |
+
raise ValueError("`semantic_generation_config` has to be provided")
|
| 627 |
+
|
| 628 |
+
batch_size = input_ids.shape[0]
|
| 629 |
+
|
| 630 |
+
max_input_semantic_length = semantic_generation_config.max_input_semantic_length
|
| 631 |
+
|
| 632 |
+
input_ids = input_ids + semantic_generation_config.text_encoding_offset
|
| 633 |
+
|
| 634 |
+
if attention_mask is not None:
|
| 635 |
+
input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token)
|
| 636 |
+
|
| 637 |
+
if history_prompt is not None:
|
| 638 |
+
semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:]
|
| 639 |
+
semantic_history = nn.functional.pad(
|
| 640 |
+
semantic_history,
|
| 641 |
+
(0, max_input_semantic_length - len(semantic_history)),
|
| 642 |
+
value=semantic_generation_config.semantic_pad_token,
|
| 643 |
+
mode="constant",
|
| 644 |
+
)
|
| 645 |
+
else:
|
| 646 |
+
semantic_history = torch.tensor(
|
| 647 |
+
[semantic_generation_config.semantic_pad_token] * max_input_semantic_length, dtype=torch.int
|
| 648 |
+
).to(self.device)
|
| 649 |
+
|
| 650 |
+
semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0)
|
| 651 |
+
|
| 652 |
+
infer_array = torch.tensor(
|
| 653 |
+
[[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int
|
| 654 |
+
).to(self.device)
|
| 655 |
+
|
| 656 |
+
input_embeds = torch.cat(
|
| 657 |
+
[
|
| 658 |
+
self.input_embeds_layer(input_ids[:, :max_input_semantic_length])
|
| 659 |
+
+ self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]),
|
| 660 |
+
self.input_embeds_layer(infer_array),
|
| 661 |
+
],
|
| 662 |
+
dim=1,
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
tokens_to_suppress = list(
|
| 666 |
+
range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token)
|
| 667 |
+
)
|
| 668 |
+
tokens_to_suppress.extend(
|
| 669 |
+
list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)
|
| 673 |
+
|
| 674 |
+
min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
|
| 675 |
+
early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
|
| 676 |
+
eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# pass input_ids in order to stay consistent with the transformers generate method even though it is not used
|
| 680 |
+
# (except to get the input seq_len - that's why we keep the first 257 tokens)
|
| 681 |
+
semantic_output = super().generate(
|
| 682 |
+
torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
|
| 683 |
+
input_embeds=input_embeds,
|
| 684 |
+
logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
|
| 685 |
+
generation_config=semantic_generation_config,
|
| 686 |
+
**kwargs,
|
| 687 |
+
) # size: 10048
|
| 688 |
+
|
| 689 |
+
# take the generated semantic tokens
|
| 690 |
+
semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
|
| 691 |
+
|
| 692 |
+
return semantic_output
|
| 693 |
+
|
| 694 |
+
|
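The prompt preparation in `BarkSemanticModel.generate` above can be illustrated in isolation (not part of the diff; the constants mirror the defaults from the generation configuration file earlier in this commit):

    import torch

    text_encoding_offset, text_pad_token = 10_048, 129_595
    input_ids = torch.tensor([[5, 9, 0]])
    attention_mask = torch.tensor([[1, 1, 0]])  # last position is padding

    # shift text tokens into their reserved id range, then overwrite padding
    shifted = input_ids + text_encoding_offset
    shifted = shifted.masked_fill((1 - attention_mask).bool(), text_pad_token)
    print(shifted)  # tensor([[ 10053,  10057, 129595]])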

@auto_docstring(
    custom_intro="""
    Bark coarse acoustics model.
    It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a
    language modeling head on top.
    """
)
class BarkCoarseModel(BarkCausalModel):
    base_model_prefix = "coarse_acoustics"
    config: BarkCoarseConfig

    def preprocess_histories(
        self,
        max_coarse_history: int,
        semantic_to_coarse_ratio: int,
        batch_size: int,
        semantic_generation_config: BarkSemanticGenerationConfig,
        codebook_size: int,
        history_prompt: Optional[dict[str, torch.Tensor]] = None,
    ):
        """
        Preprocess the optional `Bark` speaker prompts before `self.generate`.

        Args:
            max_coarse_history (`int`):
                Maximum size of coarse tokens used.
            semantic_to_coarse_ratio (`int`):
                Ratio of semantic to coarse frequency.
            batch_size (`int`):
                Batch size, i.e. the number of samples.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            codebook_size (`int`):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[dict[str,torch.Tensor]]`):
                Optional `Bark` speaker prompt.
        Returns:
            `tuple(torch.FloatTensor)`:
                - **x_semantic_history** (`torch.FloatTensor`) -- Processed semantic speaker prompt.
                - **x_coarse_history** (`torch.FloatTensor`) -- Processed coarse speaker prompt.
        """
        if history_prompt is not None:
            x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0)
            # clone to avoid modifying history_prompt.coarse_prompt
            x_coarse_history = history_prompt["coarse_prompt"].clone()

            # offset x_coarse_history
            if codebook_size is not None:
                for n in range(1, x_coarse_history.shape[0]):
                    # offset
                    x_coarse_history[n, :] += codebook_size * n

            # flatten x_coarse_history
            x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1)

            x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size

            x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0)
            # e.g.: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens
            # dedicated to second codebook.

            max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
            # trim histories correctly
            n_semantic_hist_provided = min(
                [
                    max_semantic_history,
                    x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2,
                    int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)),
                ]
            )

            n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))

            x_semantic_history = x_semantic_history[:, -n_semantic_hist_provided:].int()
            x_coarse_history = x_coarse_history[:, -n_coarse_hist_provided:].int()
            # bit of a hack for time alignment (sounds better) - from Bark original implementation
            x_coarse_history = x_coarse_history[:, :-2]

        else:
            # shape: (batch_size, 0)
            x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
            x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)

        return x_semantic_history, x_coarse_history

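
# --- standalone illustration (not part of the file above) ---------------------
# Toy walk-through of preprocess_histories' offset-and-flatten step: codebook n
# is shifted by n * codebook_size so the channels use disjoint id ranges, then
# the (n_codebooks, seq) grid is interleaved time-major. All numbers are made up.
import torch

codebook_size = 4
coarse = torch.tensor([[0, 1], [2, 3]])  # (n_codebooks=2, seq=2)
for n in range(1, coarse.shape[0]):
    coarse[n] += codebook_size * n       # second codebook becomes [6, 7]
flat = coarse.T.reshape(-1)              # tensor([0, 6, 1, 7])
# -------------------------------------------------------------------------------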
    def generate(
        self,
        semantic_output: torch.Tensor,
        semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
        coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
        codebook_size: int = 1024,
        history_prompt: Optional[dict[str, torch.Tensor]] = None,
        return_output_lengths: Optional[bool] = None,
        **kwargs,
    ) -> Union[torch.LongTensor, tuple[torch.LongTensor, torch.LongTensor]]:
        """
        Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
        prompt.

        Args:
            semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*):
                Input text semantic ids, i.e. the output of `BarkSemanticModel.generate`.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            coarse_generation_config (`BarkCoarseGenerationConfig`):
                Generation config indicating how to generate the coarse tokens.
            codebook_size (`int`, *optional*, defaults to 1024):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt.
            return_output_lengths (`bool`, *optional*):
                Whether or not to return the output lengths. Useful when batching.
        Returns:
            By default:
                torch.LongTensor: Output coarse acoustics tokens.
            If `return_output_lengths=True`:
                `tuple(torch.Tensor, torch.Tensor)`: The output coarse acoustics tokens, and the length of each sample
                of the batch.
        """

        if semantic_generation_config is None:
            raise ValueError("`semantic_generation_config` has to be provided")

        if coarse_generation_config is None:
            raise ValueError("`coarse_generation_config` has to be provided")

        max_coarse_input_length = coarse_generation_config.max_coarse_input_length
        max_coarse_history = coarse_generation_config.max_coarse_history
        sliding_window_len = coarse_generation_config.sliding_window_len

        # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e. the pad_token
        # used in the next model
        semantic_output.masked_fill_(
            semantic_output == semantic_generation_config.semantic_pad_token,
            coarse_generation_config.coarse_semantic_pad_token,
        )

        semantic_to_coarse_ratio = (
            coarse_generation_config.coarse_rate_hz
            / semantic_generation_config.semantic_rate_hz
            * coarse_generation_config.n_coarse_codebooks
        )
        max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

        output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
        output_lengths = torch.floor(
            output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
        )
        output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()

        max_generated_len = torch.max(output_lengths).item()

        batch_size = semantic_output.shape[0]

        x_semantic_history, x_coarse = self.preprocess_histories(
            history_prompt=history_prompt,
            max_coarse_history=max_coarse_history,
            semantic_to_coarse_ratio=semantic_to_coarse_ratio,
            batch_size=batch_size,
            semantic_generation_config=semantic_generation_config,
            codebook_size=codebook_size,
        )
        base_semantic_idx = x_semantic_history.shape[1]

        semantic_output = torch.hstack([x_semantic_history, semantic_output])

        n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))

        total_generated_len = 0

        len_coarse_history = x_coarse.shape[1]

        for _ in range(n_window_steps):
            semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio))

            # pad from right side
            input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :]
            input_coarse = input_coarse[:, :max_coarse_input_length]
            input_coarse = F.pad(
                input_coarse,
                (0, max_coarse_input_length - input_coarse.shape[-1]),
                "constant",
                coarse_generation_config.coarse_semantic_pad_token,
            )

            input_coarse = torch.hstack(
                [
                    input_coarse,
                    torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
                    x_coarse[:, -max_coarse_history:],
                ]
            )

            alternating_logits_processor = AlternatingCodebooksLogitsProcessor(
                input_coarse.shape[1],
                semantic_generation_config.semantic_vocab_size,
                codebook_size,
            )

            output_coarse = super().generate(
                input_coarse,
                logits_processor=[alternating_logits_processor],
                max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len),
                generation_config=coarse_generation_config,
                **kwargs,
            )

            input_coarse_len = input_coarse.shape[1]

            x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
            total_generated_len = x_coarse.shape[1] - len_coarse_history

            del output_coarse

        coarse_output = x_coarse[:, len_coarse_history:]

        if return_output_lengths:
            return coarse_output, output_lengths

        return coarse_output

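
# --- standalone illustration (not part of the file above) ---------------------
# The coarse stage budgets its output with semantic_to_coarse_ratio. With the
# rates used by the released Bark checkpoints (assumed here: 75 Hz coarse,
# 49.9 Hz semantic, 2 coarse codebooks), one semantic token maps to roughly
# three coarse tokens:
coarse_rate_hz, semantic_rate_hz, n_coarse_codebooks = 75.0, 49.9, 2
semantic_to_coarse_ratio = coarse_rate_hz / semantic_rate_hz * n_coarse_codebooks
print(round(semantic_to_coarse_ratio, 2))  # ~3.01 coarse tokens per semantic token
# -------------------------------------------------------------------------------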

@auto_docstring(
    custom_intro="""
    Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and
    language modeling heads, one for each codebook.
    """
)
class BarkFineModel(BarkPreTrainedModel):
    base_model_prefix = "fine_acoustics"
    config: BarkFineConfig
    main_input_name = "codebook_idx"

    def __init__(self, config):
        # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
        super().__init__(config)
        self.config = config

        # initialize a modified non-causal GPT-like model
        # note that there is one embedding layer and one lm_head for each codebook of Encodec
        self.input_embeds_layers = nn.ModuleList(
            [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)]
        )
        self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)

        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList(
            [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)]
        )

        self.layernorm_final = nn.LayerNorm(config.hidden_size)

        self.lm_heads = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
                for _ in range(config.n_codes_given, config.n_codes_total)
            ]
        )
        self.gradient_checkpointing = False
        self.n_codes_total = config.n_codes_total

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # one embedding layer for each codebook
        return self.input_embeds_layers

    def set_input_embeddings(self, new_embeddings):
        # one embedding layer for each codebook
        self.input_embeds_layers = new_embeddings

    def get_output_embeddings(self):
        # one lm_head for each codebook
        return self.lm_heads

    def set_output_embeddings(self, new_output_embeddings):
        # one lm_head for each codebook
        self.lm_heads = new_output_embeddings

    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
        old_embeddings_list = self.get_input_embeddings()
        new_embeddings_list = nn.ModuleList(
            [
                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
                for old_embeddings in old_embeddings_list
            ]
        )
        self.set_input_embeddings(new_embeddings_list)
        new_num_tokens = new_embeddings_list[0].weight.shape[0]

        # if word embeddings are not tied, make sure that lm head is resized as well
        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            old_lm_head_list = self.get_output_embeddings()
            new_lm_head_list = nn.ModuleList(
                [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
            )
            self.set_output_embeddings(new_lm_head_list)

        return self.get_input_embeddings()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
            mean_resizing (`bool`):
                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.

                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.config.output_vocab_size = model_embeds[0].weight.shape[0]
        self.config.vocab_size = model_embeds[0].weight.shape[0]
        self.output_vocab_size = model_embeds[0].weight.shape[0]
        self.vocab_size = model_embeds[0].weight.shape[0]

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds

    def _tie_weights(self):
        if getattr(self.config, "tie_word_embeddings", True):
            self._tied_weights_keys = []
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()

            for i in range(self.config.n_codes_total - self.config.n_codes_given):
                # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
                self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
                self._tied_weights_keys.append(f"lm_heads.{i}.weight")

    def tie_weights(self):
        """
        Tie the weights between the input embeddings list and the output embeddings list.

        If the `torchscript` flag is set in the configuration, parameter sharing can't be handled, so the weights are
        cloned instead.
        """
        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()

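
# --- standalone illustration (not part of the file above) ---------------------
# A hedged sanity check of the tying scheme above: with tie_word_embeddings on
# (the default), lm_heads[i] should share storage with input_embeds_layers[i + 1].
# A default BarkFineConfig is assumed; shapes depend on the checkpoint.
from transformers import BarkFineConfig, BarkFineModel

model = BarkFineModel(BarkFineConfig())
model.tie_weights()
assert model.lm_heads[0].weight.data_ptr() == model.input_embeds_layers[1].weight.data_ptr()
# -------------------------------------------------------------------------------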
    @auto_docstring
    def forward(
        self,
        codebook_idx: int,  # an additional idx corresponding to the id of the codebook that will be predicted
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        codebook_idx (`int`):
            Index of the codebook that will be predicted.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            NOT IMPLEMENTED YET.
        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
            `past_key_values` is used, optionally only the last `input_embeds` have to be input (see
            `past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into
            associated vectors than the model's internal embedding lookup matrix.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        if codebook_idx == 0:
            raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model")

        if input_ids is not None and input_embeds is not None:
            raise ValueError("You cannot specify both input_ids and input_embeds at the same time")

        if input_ids is None and input_embeds is None:
            raise ValueError("You have to specify either input_ids or input_embeds")

        if input_ids is not None:
            # the input_embeddings are the sum of the j previous codebook embeddings before
            # the current codebook_idx codebook

            # forward the GPT model itself
            input_embeds = [
                input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1)
                for i, input_embeds_layer in enumerate(self.input_embeds_layers)
            ]  # token embeddings of shape (b, t, n_embd)
            input_embeds = torch.cat(input_embeds, dim=-1)
            input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)

        input_shape = input_embeds.size()[:-1]
        batch_size = input_embeds.shape[0]
        seq_length = input_shape[1]

        device = input_ids.device if input_ids is not None else input_embeds.device

        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)  # shape (1, seq_length)

        position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if self.config._attn_implementation == "flash_attention_2":
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
                # from_seq_length is 1 to easily broadcast
                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)

        hidden_states = self.drop(input_embeds + position_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        hidden_states = self.layernorm_final(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states)

        if not return_dict:
            return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

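
# --- standalone illustration (not part of the file above) ---------------------
# Toy version of the embedding sum in BarkFineModel.forward: every codebook gets
# its own embedding table, and the transformer input is the sum of the embeddings
# of codebooks 0..codebook_idx. All sizes here are made up.
import torch
import torch.nn as nn

n_codes, vocab, hidden = 8, 16, 4
layers = nn.ModuleList(nn.Embedding(vocab, hidden) for _ in range(n_codes))
ids = torch.randint(vocab, (1, 5, n_codes))  # (batch, seq, n_codes)
stacked = torch.cat(
    [layer(ids[:, :, i]).unsqueeze(-1) for i, layer in enumerate(layers)], dim=-1
)                                            # (1, 5, hidden, n_codes)
codebook_idx = 3
summed = stacked[:, :, :, : codebook_idx + 1].sum(dim=-1)  # (1, 5, hidden)
# -------------------------------------------------------------------------------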
    @torch.no_grad()
    def generate(
        self,
        coarse_output: torch.Tensor,
        semantic_generation_config: Optional[BarkSemanticGenerationConfig] = None,
        coarse_generation_config: Optional[BarkCoarseGenerationConfig] = None,
        fine_generation_config: Optional[BarkFineGenerationConfig] = None,
        codebook_size: int = 1024,
        history_prompt: Optional[dict[str, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker
        prompt.

        Args:
            coarse_output (`torch.Tensor` of shape (batch_size, seq_len)):
                Input coarse acoustics ids, i.e. the output of `BarkCoarseModel.generate`.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            coarse_generation_config (`BarkCoarseGenerationConfig`):
                Generation config indicating how to generate the coarse tokens.
            fine_generation_config (`BarkFineGenerationConfig`):
                Generation config indicating how to generate the fine tokens.
            codebook_size (`int`, *optional*, defaults to 1024):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt.
        Returns:
            torch.LongTensor: Output fine acoustics tokens.
        """
        if semantic_generation_config is None:
            raise ValueError("`semantic_generation_config` has to be provided")

        if coarse_generation_config is None:
            raise ValueError("`coarse_generation_config` has to be provided")

        if fine_generation_config is None:
            raise ValueError("`fine_generation_config` has to be provided")

        # since we don't really use GenerationConfig through the fine model (autoencoder)
        # and since only temperature is used from the classic GenerationConfig parameters
        # manually impose the kwargs priority over the generation config
        temperature = kwargs.get("temperature", fine_generation_config.temperature)

        max_fine_history_length = fine_generation_config.max_fine_history_length
        max_fine_input_length = fine_generation_config.max_fine_input_length

        # shape: (batch, n_coarse_codebooks * seq_len)
        # new_shape: (batch, seq_len, n_coarse_codebooks)
        coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks)

        # brings ids into the range [0, codebook_size - 1]
        coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size)
        batch_size = coarse_output.shape[0]

        if history_prompt is not None:
            x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0)
            # transpose to get to shape (seq_len, n_fine_codebooks)
        else:
            x_fine_history = None

        n_coarse = coarse_generation_config.n_coarse_codebooks

        # pad the remaining fine codebooks (those not predicted by the coarse model)
        fine_input = F.pad(
            coarse_output,
            (0, fine_generation_config.n_fine_codebooks - n_coarse),
            "constant",
            codebook_size,
        )

        # prepend history if available (at most max_fine_history_length)
        if x_fine_history is not None:
            fine_input = torch.cat([x_fine_history[:, -max_fine_history_length:, :], fine_input], dim=1)

            # len of the fine_history that has been added to fine_input
            n_history = x_fine_history[:, -max_fine_history_length:, :].shape[1]
        else:
            n_history = 0

        n_remove_from_end = 0
        # need to pad if too short (since non-causal model)
        if fine_input.shape[1] < max_fine_input_length:
            n_remove_from_end = max_fine_input_length - fine_input.shape[1]
            fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size)

        # we can be lazy about fractional loop and just keep overwriting codebooks.
        # seems that coarse_output.shape[1] - (max_fine_input_length - n_history) is equal to minus n_remove_from_end
        # So if we needed to pad because too short, n_loops is always 1 (because n_remove_from_end > 0)
        # If not, we loop over at least twice.

        n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length
        n_loops = int(np.ceil(n_loops))
        n_loops = max(0, n_loops) + 1

        for n_outer in range(n_loops):
            start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length])

            start_fill_idx = min(
                [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length]
            )
            rel_start_fill_idx = start_fill_idx - start_idx
            input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
            for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                logits = self.forward(n_inner, input_buffer).logits
                if temperature is None or temperature == 1.0:
                    relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
                    codebook_preds = torch.argmax(relevant_logits, -1)
                else:
                    relevant_logits = logits[:, :, :codebook_size] / temperature
                    # apply softmax
                    probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length]
                    # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size)
                    probs = probs.reshape((-1, codebook_size))
                    # multinomial then reshape: (batch_size*seq_len) -> (batch_size, seq_len)
                    codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1)
                codebook_preds = codebook_preds.to(torch.int32)
                input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds
                del logits, codebook_preds

            # transfer into fine_input
            for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                fine_input[
                    :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner
                ] = input_buffer[:, rel_start_fill_idx:, n_inner]
            del input_buffer

        fine_input = fine_input.transpose(1, 2)[:, :, n_history:]
        if n_remove_from_end > 0:
            fine_input = fine_input[:, :, :-n_remove_from_end]

        if fine_input.shape[-1] != coarse_output.shape[-2]:
            raise ValueError("input and output should have the same seq_len")

        return fine_input

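
# --- standalone illustration (not part of the file above) ---------------------
# Toy check of the sliding-window loop count in BarkFineModel.generate. With the
# window sizes assumed below (made-up but representative), a 1500-frame coarse
# output needs ceil((1500 - 512) / 512) + 1 = 3 window passes:
import numpy as np

seq_len, max_fine_input_length, max_fine_history_length, n_history = 1500, 1024, 512, 512
n_loops = max(0, int(np.ceil((seq_len - (max_fine_input_length - n_history)) / max_fine_history_length))) + 1
print(n_loops)  # 3
# -------------------------------------------------------------------------------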

@auto_docstring(
    custom_intro="""
    The full Bark model, a text-to-speech model composed of 4 sub-models:
    - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal auto-regressive transformer model that
    takes as input tokenized text, and predicts semantic text tokens that capture the meaning of the text.
    - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model), also a causal autoregressive transformer,
    that takes as input the results of the last model. It aims at regressing the first two audio codebooks necessary
    for `encodec`.
    - [`BarkFineModel`] (the 'fine acoustics' model), this time a non-causal autoencoder transformer, which iteratively
    predicts the last codebooks based on the sum of the previous codebooks' embeddings.
    - having predicted all the codebook channels of the [`EncodecModel`], Bark uses it to decode the output audio
    array.

    It should be noted that each of the first three modules can support conditional speaker embeddings to condition the
    output sound according to a specific predefined voice.
    """
)
class BarkModel(BarkPreTrainedModel):
    config: BarkConfig

    def __init__(self, config):
        super().__init__(config)

        self.semantic = BarkSemanticModel(config.semantic_config)
        self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
        self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)

        self.codec_model = AutoModel.from_config(config.codec_config)

        self.config = config

    @classmethod
    def can_generate(cls) -> bool:
        # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from
        # `GenerationMixin` (it has a non-standard generation method), but one of the internal models does
        # (`BarkSemanticModel`). This means that the base `can_generate()` will return `False`, but we need to
        # override it so as to do `GenerationConfig` handling in multiple parts of the codebase.
        return True

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        # for bark_model, device must be verified on its sub-models
        # if it has _hf_hook, it has been offloaded, so the device has to be found in the hook
        if not hasattr(self.semantic, "_hf_hook"):
            return get_parameter_device(self)
        for module in self.semantic.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)

    def enable_cpu_offload(
        self,
        accelerator_id: Optional[int] = 0,
        **kwargs,
    ):
        r"""
        Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
        method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains on
        the accelerator until the next sub-model runs.

        Args:
            accelerator_id (`int`, *optional*, defaults to 0):
                Accelerator id on which the sub-models will be loaded and offloaded.
            kwargs (`dict`, *optional*):
                additional keyword arguments:
                `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded. This argument is
                deprecated.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")

        gpu_id = kwargs.get("gpu_id", 0)

        if gpu_id != 0:
            warnings.warn(
                "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
                FutureWarning,
            )
            accelerator_id = gpu_id

        device_type = "cuda"
        if is_torch_accelerator_available():
            device_type = torch.accelerator.current_accelerator().type
        device = torch.device(f"{device_type}:{accelerator_id}")

        torch_accelerator_module = getattr(torch, device_type)
        if self.device.type != "cpu":
            self.to("cpu")
            torch_accelerator_module.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # this layer is used outside the first forward pass of semantic so it needs to be loaded before semantic
        self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)

        hook = None
        for cpu_offloaded_model in [
            self.semantic,
            self.coarse_acoustics,
            self.fine_acoustics,
        ]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        self.fine_acoustics_hook = hook

        _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.codec_model_hook = hook

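
# --- standalone illustration (not part of the file above) ---------------------
# Hedged usage sketch of the offloading hooks above: each sub-model hops to the
# accelerator only while its stage runs. Requires `accelerate` and an available
# accelerator; checkpoint name as published by Suno.
from transformers import BarkModel

model = BarkModel.from_pretrained("suno/bark-small")
model.enable_cpu_offload()  # sub-models now live on CPU between stages
# -------------------------------------------------------------------------------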
    def codec_decode(self, fine_output, output_lengths=None):
        """Turn quantized audio codes into audio array using encodec."""

        fine_output = fine_output.transpose(0, 1)
        emb = self.codec_model.quantizer.decode(fine_output)

        if output_lengths is not None:
            # encodec uses LSTMs which behave differently with appended padding
            # decoding with encodec takes around 0.1% of the total generation time
            # to keep generation quality, we break batching
            out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
            audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
        else:
            out = self.codec_model.decoder(emb)
            audio_arr = out.squeeze(1)  # squeeze the codebook dimension

        return audio_arr

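
# --- standalone illustration (not part of the file above) ---------------------
# Toy version of the per-sample truncation in codec_decode: each sample is cut
# to its true length before decoding so EnCodec's LSTMs never see padding.
# Shapes are made up for the demonstration.
import torch

emb = torch.randn(2, 8, 50)              # (batch, embedding_dim, frames)
output_lengths = torch.tensor([50, 30])  # true frame counts per sample
out = [sample[:, :length].unsqueeze(0) for sample, length in zip(emb, output_lengths)]
print(out[1].shape)  # torch.Size([1, 8, 30])
# -------------------------------------------------------------------------------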
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        history_prompt: Optional[dict[str, torch.Tensor]] = None,
        return_output_lengths: Optional[bool] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates audio from an input prompt and an additional optional `Bark` speaker prompt.

        Args:
            input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
                Input ids. Will be truncated up to 256 tokens. Note that the output audio will be as long as the
                longest generation among the batch.
            history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
            kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:

                - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
                - With a *semantic_*, *coarse_*, *fine_* prefix, they will be input for the `generate` method of the
                semantic, coarse and fine sub-models respectively. They take priority over the keywords without a
                prefix.

                This means you can, for example, specify a generation strategy for all sub-models except one.
            return_output_lengths (`bool`, *optional*):
                Whether or not to return the waveform lengths. Useful when batching.
        Returns:
            By default:
                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
            When `return_output_lengths=True`:
                Returns a tuple made of:
                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
                - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch.
        Example:

        ```python
        >>> from transformers import AutoProcessor, BarkModel

        >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
        >>> model = BarkModel.from_pretrained("suno/bark-small")

        >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
        >>> voice_preset = "v2/en_speaker_6"

        >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)

        >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
        >>> audio_array = audio_array.cpu().numpy().squeeze()
        ```
        """
        # TODO (joao): workaround until nested generation config is compatible with PreTrainedModel
        # todo: dict
        semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config)
        coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config)
        fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config)

        kwargs_semantic = {
            # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
            "attention_mask": kwargs.pop("attention_mask", None),
            "min_eos_p": kwargs.pop("min_eos_p", None),
        }
        kwargs_coarse = {}
        kwargs_fine = {}
        for key, value in kwargs.items():
            if key.startswith("semantic_"):
                key = key[len("semantic_") :]
                kwargs_semantic[key] = value
            elif key.startswith("coarse_"):
                key = key[len("coarse_") :]
                kwargs_coarse[key] = value
            elif key.startswith("fine_"):
                key = key[len("fine_") :]
                kwargs_fine[key] = value
            else:
                # If the key is already in a specific config, then it's been set with a
                # submodule-specific value and we don't override
                if key not in kwargs_semantic:
                    kwargs_semantic[key] = value
                if key not in kwargs_coarse:
                    kwargs_coarse[key] = value
                if key not in kwargs_fine:
                    kwargs_fine[key] = value

        # 1. Generate from the semantic model
        if "generation_config" in kwargs_semantic:
            kwargs_semantic.pop("generation_config")
        semantic_output = self.semantic.generate(
            input_ids,
            history_prompt=history_prompt,
            semantic_generation_config=semantic_generation_config,
            **kwargs_semantic,
        )

        # 2. Generate from the coarse model
        if "generation_config" in kwargs_coarse:
            kwargs_coarse.pop("generation_config")
        coarse_output = self.coarse_acoustics.generate(
            semantic_output,
            history_prompt=history_prompt,
            semantic_generation_config=semantic_generation_config,
            coarse_generation_config=coarse_generation_config,
            codebook_size=self.generation_config.codebook_size,
            return_output_lengths=return_output_lengths,
            **kwargs_coarse,
        )

        output_lengths = None
        if return_output_lengths:
            coarse_output, output_lengths = coarse_output
            # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
            output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks

        # 3. "generate" from the fine model
        if "generation_config" in kwargs_fine:
            kwargs_fine.pop("generation_config")
        output = self.fine_acoustics.generate(
            coarse_output,
            history_prompt=history_prompt,
            semantic_generation_config=semantic_generation_config,
            coarse_generation_config=coarse_generation_config,
            fine_generation_config=fine_generation_config,
            codebook_size=self.generation_config.codebook_size,
            **kwargs_fine,
        )

        if getattr(self, "fine_acoustics_hook", None) is not None:
            # Manually offload fine_acoustics to CPU
            # and load codec_model to GPU
            # since bark doesn't use codec_model forward pass
            self.fine_acoustics_hook.offload()
            self.codec_model = self.codec_model.to(self.device)

        # 4. Decode the output and generate audio array
        audio = self.codec_decode(output, output_lengths)

        if getattr(self, "codec_model_hook", None) is not None:
            # Offload codec_model to CPU
            self.codec_model_hook.offload()

        if return_output_lengths:
            output_lengths = [len(sample) for sample in audio]
            audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
            return audio, output_lengths

        return audio

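
# --- standalone illustration (not part of the file above) ---------------------
# Hedged sketch of the kwargs routing in BarkModel.generate: a prefixed key goes
# to one sub-model and wins over the unprefixed version of the same key.
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")
inputs = processor("Routing demo sentence.")
audio = model.generate(**inputs, temperature=0.7, coarse_temperature=0.9)
# semantic and fine stages decode with temperature=0.7; the coarse stage uses 0.9
# -------------------------------------------------------------------------------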
    def tie_weights(self):
        """
        Tie the weights between the input embeddings list and the output embeddings list.

        If the `torchscript` flag is set in the configuration, parameter sharing can't be handled, so the weights are
        cloned instead.
        """
        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()


__all__ = [
    "BarkFineModel",
    "BarkSemanticModel",
    "BarkCoarseModel",
    "BarkModel",
    "BarkPreTrainedModel",
    "BarkCausalModel",
]
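
# --- standalone illustration (not part of the file above) ---------------------
# End-to-end sketch of the four-stage pipeline the classes above implement:
# text -> semantic tokens -> coarse tokens -> fine tokens -> waveform. It uses
# the public BarkModel.generate, which chains the stages internally; the
# sample-rate lookup via generation_config is an assumption about the config
# layout of the published checkpoints.
import scipy.io.wavfile

from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained("suno/bark-small")

inputs = processor("A short test sentence.", voice_preset="v2/en_speaker_6")
audio = model.generate(**inputs).cpu().numpy().squeeze()
scipy.io.wavfile.write("bark_out.wav", rate=model.generation_config.sample_rate, data=audio)
# -------------------------------------------------------------------------------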
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bark/processing_bark.py
ADDED
@@ -0,0 +1,340 @@
# coding=utf-8
# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Bark
"""

import json
import os
from typing import Optional

import numpy as np

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from ...utils.hub import cached_file
from ..auto import AutoTokenizer


logger = logging.get_logger(__name__)


class BarkProcessor(ProcessorMixin):
    r"""
    Constructs a Bark processor which wraps a text tokenizer and optional Bark voice presets into a single processor.

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            An instance of [`PreTrainedTokenizer`].
        speaker_embeddings (`dict[dict[str]]`, *optional*):
            Optional nested speaker embeddings dictionary. The first level contains voice preset names (e.g.
            `"en_speaker_4"`). The second level contains `"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`
            embeddings. The values correspond to the path of the corresponding `np.ndarray`. See
            [here](https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c) for
            a list of `voice_preset_names`.

    """

    tokenizer_class = "AutoTokenizer"
    attributes = ["tokenizer"]

    preset_shape = {
        "semantic_prompt": 1,  # 1D array of shape (X,)
        "coarse_prompt": 2,  # 2D array of shape (2,X)
        "fine_prompt": 2,  # 2D array of shape (8,X)
    }

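
# --- standalone illustration (not part of the file above) ---------------------
# Toy arrays matching the documented preset shapes: a 1D semantic prompt and 2D
# coarse/fine prompts with one row per codebook. Lengths are made up.
import numpy as np
from transformers import BarkProcessor

preset = {
    "semantic_prompt": np.zeros(100, dtype=np.int64),     # 1D, (X,)
    "coarse_prompt": np.zeros((2, 300), dtype=np.int64),  # 2D, (2, X)
    "fine_prompt": np.zeros((8, 300), dtype=np.int64),    # 2D, (8, X)
}
assert all(preset[k].ndim == BarkProcessor.preset_shape[k] for k in preset)
# -------------------------------------------------------------------------------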
    def __init__(self, tokenizer, speaker_embeddings=None):
        super().__init__(tokenizer)

        self.speaker_embeddings = speaker_embeddings

    @classmethod
    def from_pretrained(
        cls, pretrained_processor_name_or_path, speaker_embeddings_dict_path="speaker_embeddings_path.json", **kwargs
    ):
        r"""
        Instantiate a Bark processor associated with a pretrained model.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained [`BarkProcessor`] hosted inside a model repo on
                huggingface.co.
                - a path to a *directory* containing a processor saved using the [`~BarkProcessor.save_pretrained`]
                method, e.g., `./my_model_directory/`.
            speaker_embeddings_dict_path (`str`, *optional*, defaults to `"speaker_embeddings_path.json"`):
                The name of the `.json` file containing the speaker_embeddings dictionary located in
                `pretrained_model_name_or_path`. If `None`, no speaker_embeddings is loaded.
            **kwargs
                Additional keyword arguments passed along to
                [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
        """

        if speaker_embeddings_dict_path is not None:
            speaker_embeddings_path = cached_file(
                pretrained_processor_name_or_path,
                speaker_embeddings_dict_path,
                subfolder=kwargs.pop("subfolder", None),
                cache_dir=kwargs.pop("cache_dir", None),
                force_download=kwargs.pop("force_download", False),
                proxies=kwargs.pop("proxies", None),
                resume_download=kwargs.pop("resume_download", None),
                local_files_only=kwargs.pop("local_files_only", False),
                token=kwargs.pop("use_auth_token", None),
                revision=kwargs.pop("revision", None),
                _raise_exceptions_for_gated_repo=False,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
            )
            if speaker_embeddings_path is None:
                logger.warning(
                    f"""`{os.path.join(pretrained_processor_name_or_path, speaker_embeddings_dict_path)}` does not exist,
                    no preloaded speaker embeddings will be used - Make sure to provide a correct path to the json
                    dictionary if wanted, otherwise set `speaker_embeddings_dict_path=None`."""
                )
                speaker_embeddings = None
            else:
                with open(speaker_embeddings_path) as speaker_embeddings_json:
                    speaker_embeddings = json.load(speaker_embeddings_json)
        else:
            speaker_embeddings = None

        if speaker_embeddings is not None:
            if "repo_or_path" in speaker_embeddings:
                speaker_embeddings["repo_or_path"] = pretrained_processor_name_or_path
        tokenizer = AutoTokenizer.from_pretrained(pretrained_processor_name_or_path, **kwargs)

        return cls(tokenizer=tokenizer, speaker_embeddings=speaker_embeddings)

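
# --- standalone illustration (not part of the file above) ---------------------
# Hedged usage sketch of the loader above: fetch the tokenizer plus the
# speaker-embeddings index from the published Bark checkpoint and list a few of
# the voice presets it declares.
from transformers import BarkProcessor

processor = BarkProcessor.from_pretrained("suno/bark-small")
print(processor.available_voice_presets[:3])
# -------------------------------------------------------------------------------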
    def save_pretrained(
        self,
        save_directory,
        speaker_embeddings_dict_path="speaker_embeddings_path.json",
        speaker_embeddings_directory="speaker_embeddings",
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Saves the attributes of this processor (tokenizer...) in the specified directory so that it can be reloaded
        using the [`~BarkProcessor.from_pretrained`] method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the tokenizer files and the speaker embeddings will be saved (directory will be created
                if it does not exist).
            speaker_embeddings_dict_path (`str`, *optional*, defaults to `"speaker_embeddings_path.json"`):
                The name of the `.json` file that will contain the speaker_embeddings nested path dictionary, if it
                exists, and that will be located in `pretrained_model_name_or_path/speaker_embeddings_directory`.
            speaker_embeddings_directory (`str`, *optional*, defaults to `"speaker_embeddings/"`):
                The name of the folder in which the speaker_embeddings arrays will be saved.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs:
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if self.speaker_embeddings is not None:
            os.makedirs(os.path.join(save_directory, speaker_embeddings_directory, "v2"), exist_ok=True)

            embeddings_dict = {}

            embeddings_dict["repo_or_path"] = save_directory

            for prompt_key in self.available_voice_presets:
                voice_preset = self._load_voice_preset(prompt_key)

                tmp_dict = {}
                for key in self.speaker_embeddings[prompt_key]:
                    np.save(
                        os.path.join(
                            embeddings_dict["repo_or_path"], speaker_embeddings_directory, f"{prompt_key}_{key}"
                        ),
                        voice_preset[key],
                        allow_pickle=False,
                    )
                    tmp_dict[key] = os.path.join(speaker_embeddings_directory, f"{prompt_key}_{key}.npy")

                embeddings_dict[prompt_key] = tmp_dict

            with open(os.path.join(save_directory, speaker_embeddings_dict_path), "w") as fp:
                json.dump(embeddings_dict, fp)

        super().save_pretrained(save_directory, push_to_hub, **kwargs)

    def _load_voice_preset(self, voice_preset: Optional[str] = None, **kwargs):
        voice_preset_paths = self.speaker_embeddings[voice_preset]

        voice_preset_dict = {}
        for key in ["semantic_prompt", "coarse_prompt", "fine_prompt"]:
            if key not in voice_preset_paths:
                raise ValueError(
                    f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]."
                )

            path = cached_file(
                self.speaker_embeddings.get("repo_or_path", "/"),
                voice_preset_paths[key],
                subfolder=kwargs.pop("subfolder", None),
                cache_dir=kwargs.pop("cache_dir", None),
                force_download=kwargs.pop("force_download", False),
                proxies=kwargs.pop("proxies", None),
                resume_download=kwargs.pop("resume_download", None),
                local_files_only=kwargs.pop("local_files_only", False),
                token=kwargs.pop("use_auth_token", None),
                revision=kwargs.pop("revision", None),
                _raise_exceptions_for_gated_repo=False,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
            )
            if path is None:
                raise ValueError(
                    f"""`{os.path.join(self.speaker_embeddings.get("repo_or_path", "/"), voice_preset_paths[key])}` does not exist,
                    no preloaded voice preset will be used - Make sure to provide correct paths to the {voice_preset}
                    embeddings."""
                )

            voice_preset_dict[key] = np.load(path)

        return voice_preset_dict

| 217 |
+
def _validate_voice_preset_dict(self, voice_preset: Optional[dict] = None):
|
| 218 |
+
for key in ["semantic_prompt", "coarse_prompt", "fine_prompt"]:
|
| 219 |
+
if key not in voice_preset:
|
| 220 |
+
raise ValueError(f"Voice preset unrecognized, missing {key} as a key.")
|
| 221 |
+
|
| 222 |
+
if not isinstance(voice_preset[key], np.ndarray):
|
| 223 |
+
raise TypeError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.")
|
| 224 |
+
|
| 225 |
+
if len(voice_preset[key].shape) != self.preset_shape[key]:
|
| 226 |
+
raise ValueError(f"{key} voice preset must be a {str(self.preset_shape[key])}D ndarray.")
|
| 227 |
+
|
| 228 |
+
@property
|
| 229 |
+
def available_voice_presets(self) -> list:
|
| 230 |
+
"""
|
| 231 |
+
Returns a list of available voice presets.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
`list[str]`: A list of voice preset names.
|
| 235 |
+
"""
|
| 236 |
+
if self.speaker_embeddings is None:
|
| 237 |
+
return []
|
| 238 |
+
|
| 239 |
+
voice_presets = list(self.speaker_embeddings.keys())
|
| 240 |
+
if "repo_or_path" in voice_presets:
|
| 241 |
+
voice_presets.remove("repo_or_path")
|
| 242 |
+
return voice_presets
|
| 243 |
+
|
| 244 |
+
def _verify_speaker_embeddings(self, remove_unavailable: bool = True):
|
| 245 |
+
# check which actually downloaded properly / are available
|
| 246 |
+
unavailable_keys = []
|
| 247 |
+
if self.speaker_embeddings is not None:
|
| 248 |
+
for voice_preset in self.available_voice_presets:
|
| 249 |
+
try:
|
| 250 |
+
voice_preset_dict = self._load_voice_preset(voice_preset)
|
| 251 |
+
except ValueError:
|
| 252 |
+
# error from `_load_voice_preset` of path not existing
|
| 253 |
+
unavailable_keys.append(voice_preset)
|
| 254 |
+
continue
|
| 255 |
+
self._validate_voice_preset_dict(voice_preset_dict)
|
| 256 |
+
|
| 257 |
+
if unavailable_keys:
|
| 258 |
+
logger.warning(
|
| 259 |
+
f"The following {len(unavailable_keys)} speaker embeddings are not available: {unavailable_keys} "
|
| 260 |
+
"If you would like to use them, please check the paths or try downloading them again."
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
if remove_unavailable:
|
| 264 |
+
for voice_preset in unavailable_keys:
|
| 265 |
+
del self.speaker_embeddings[voice_preset]
|
| 266 |
+
|
| 267 |
+
def __call__(
|
| 268 |
+
self,
|
| 269 |
+
text=None,
|
| 270 |
+
voice_preset=None,
|
| 271 |
+
return_tensors="pt",
|
| 272 |
+
max_length=256,
|
| 273 |
+
add_special_tokens=False,
|
| 274 |
+
return_attention_mask=True,
|
| 275 |
+
return_token_type_ids=False,
|
| 276 |
+
**kwargs,
|
| 277 |
+
) -> BatchEncoding:
|
| 278 |
+
"""
|
| 279 |
+
Main method to prepare for the model one or several sequences(s). This method forwards the `text` and `kwargs`
|
| 280 |
+
arguments to the AutoTokenizer's [`~AutoTokenizer.__call__`] to encode the text. The method also proposes a
|
| 281 |
+
voice preset which is a dictionary of arrays that conditions `Bark`'s output. `kwargs` arguments are forwarded
|
| 282 |
+
to the tokenizer and to `cached_file` method if `voice_preset` is a valid filename.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
text (`str`, `list[str]`, `list[list[str]]`):
|
| 286 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
| 287 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
| 288 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
| 289 |
+
voice_preset (`str`, `dict[np.ndarray]`):
|
| 290 |
+
The voice preset, i.e the speaker embeddings. It can either be a valid voice_preset name, e.g
|
| 291 |
+
`"en_speaker_1"`, or directly a dictionary of `np.ndarray` embeddings for each submodel of `Bark`. Or
|
| 292 |
+
it can be a valid file name of a local `.npz` single voice preset containing the keys
|
| 293 |
+
`"semantic_prompt"`, `"coarse_prompt"` and `"fine_prompt"`.
|
| 294 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
| 295 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
| 296 |
+
|
| 297 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
| 298 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
[`BatchEncoding`]: A [`BatchEncoding`] object containing the output of the `tokenizer`.
|
| 302 |
+
If a voice preset is provided, the returned object will include a `"history_prompt"` key
|
| 303 |
+
containing a [`BatchFeature`], i.e the voice preset with the right tensors type.
|
| 304 |
+
"""
|
| 305 |
+
if voice_preset is not None and not isinstance(voice_preset, dict):
|
| 306 |
+
if (
|
| 307 |
+
isinstance(voice_preset, str)
|
| 308 |
+
and self.speaker_embeddings is not None
|
| 309 |
+
and voice_preset in self.speaker_embeddings
|
| 310 |
+
):
|
| 311 |
+
voice_preset = self._load_voice_preset(voice_preset)
|
| 312 |
+
|
| 313 |
+
else:
|
| 314 |
+
if isinstance(voice_preset, str) and not voice_preset.endswith(".npz"):
|
| 315 |
+
voice_preset = voice_preset + ".npz"
|
| 316 |
+
|
| 317 |
+
voice_preset = np.load(voice_preset)
|
| 318 |
+
|
| 319 |
+
if voice_preset is not None:
|
| 320 |
+
self._validate_voice_preset_dict(voice_preset, **kwargs)
|
| 321 |
+
voice_preset = BatchFeature(data=voice_preset, tensor_type=return_tensors)
|
| 322 |
+
|
| 323 |
+
encoded_text = self.tokenizer(
|
| 324 |
+
text,
|
| 325 |
+
return_tensors=return_tensors,
|
| 326 |
+
padding="max_length",
|
| 327 |
+
max_length=max_length,
|
| 328 |
+
return_attention_mask=return_attention_mask,
|
| 329 |
+
return_token_type_ids=return_token_type_ids,
|
| 330 |
+
add_special_tokens=add_special_tokens,
|
| 331 |
+
**kwargs,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
if voice_preset is not None:
|
| 335 |
+
encoded_text["history_prompt"] = voice_preset
|
| 336 |
+
|
| 337 |
+
return encoded_text
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
__all__ = ["BarkProcessor"]
|
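A minimal usage sketch for the processor methods above (hedged: the checkpoint name
"suno/bark-small" and the preset "v2/en_speaker_6" are illustrative values, not taken
from this file):

    from transformers import BarkProcessor

    processor = BarkProcessor.from_pretrained("suno/bark-small")
    inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")
    print(inputs["input_ids"].shape)        # tokenizer output, padded to max_length=256 by default
    print(inputs["history_prompt"].keys())  # semantic/coarse/fine prompt arrays as a BatchFeature

    # Round-trips the tokenizer files, the .npy speaker embedding arrays, and the
    # JSON path dictionary written by save_pretrained above.
    processor.save_pretrained("./bark-processor")
    reloaded = BarkProcessor.from_pretrained("./bark-processor")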
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/__init__.py
ADDED
@@ -0,0 +1,32 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_bert import *
    from .modeling_bert import *
    from .modeling_flax_bert import *
    from .modeling_tf_bert import *
    from .tokenization_bert import *
    from .tokenization_bert_fast import *
    from .tokenization_bert_tf import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
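A short sketch of what the lazy-module pattern above buys the caller (assuming a
standard transformers install; nothing heavy such as torch or tensorflow is imported
until an attribute is actually accessed):

    import transformers.models.bert as bert

    # Attribute access is what triggers the real import of configuration_bert.
    config = bert.BertConfig(vocab_size=1000)
    print(config.vocab_size)  # 1000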
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/configuration_bert.py
ADDED
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model configuration"""

from collections import OrderedDict
from collections.abc import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
    instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import BertConfig, BertModel

    >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
    >>> configuration = BertConfig()

    >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
    >>> model = BertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout


class BertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
                ("token_type_ids", dynamic_axis),
            ]
        )


__all__ = ["BertConfig", "BertOnnxConfig"]
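A brief sketch of BertOnnxConfig in use (a hedged example, assuming the optional ONNX
export dependencies are installed; output shown for the default task):

    from transformers import BertConfig
    from transformers.models.bert.configuration_bert import BertOnnxConfig

    onnx_config = BertOnnxConfig(BertConfig())
    print(onnx_config.inputs)
    # OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
    #              ('attention_mask', {0: 'batch', 1: 'sequence'}),
    #              ('token_type_ids', {0: 'batch', 1: 'sequence'})])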
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_bert.py
ADDED
@@ -0,0 +1,1801 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_bert import BertConfig


logger = logging.get_logger(__name__)


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model

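# Illustrative usage of load_tf_weights_in_bert (a hedged sketch; the checkpoint and
# config paths below are placeholders, not part of this module):
#
#     from transformers import BertConfig, BertForPreTraining
#     from transformers.models.bert.modeling_bert import load_tf_weights_in_bert
#
#     config = BertConfig.from_json_file("bert_config.json")
#     model = BertForPreTraining(config)
#     load_tf_weights_in_bert(model, config, "bert_model.ckpt")
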
+
class BertEmbeddings(nn.Module):
|
| 128 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 129 |
+
|
| 130 |
+
def __init__(self, config):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 133 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 134 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 135 |
+
|
| 136 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 137 |
+
# any TensorFlow checkpoint file
|
| 138 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 139 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 140 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 141 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 142 |
+
self.register_buffer(
|
| 143 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 144 |
+
)
|
| 145 |
+
self.register_buffer(
|
| 146 |
+
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def forward(
|
| 150 |
+
self,
|
| 151 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 152 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 153 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 154 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 155 |
+
past_key_values_length: int = 0,
|
| 156 |
+
) -> torch.Tensor:
|
| 157 |
+
if input_ids is not None:
|
| 158 |
+
input_shape = input_ids.size()
|
| 159 |
+
else:
|
| 160 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 161 |
+
|
| 162 |
+
seq_length = input_shape[1]
|
| 163 |
+
|
| 164 |
+
if position_ids is None:
|
| 165 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 166 |
+
|
| 167 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
| 168 |
+
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
| 169 |
+
# issue #5664
|
| 170 |
+
if token_type_ids is None:
|
| 171 |
+
if hasattr(self, "token_type_ids"):
|
| 172 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
| 173 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
| 174 |
+
token_type_ids = buffered_token_type_ids_expanded
|
| 175 |
+
else:
|
| 176 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 177 |
+
|
| 178 |
+
if inputs_embeds is None:
|
| 179 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 180 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 181 |
+
|
| 182 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 183 |
+
if self.position_embedding_type == "absolute":
|
| 184 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 185 |
+
embeddings += position_embeddings
|
| 186 |
+
embeddings = self.LayerNorm(embeddings)
|
| 187 |
+
embeddings = self.dropout(embeddings)
|
| 188 |
+
return embeddings
|
| 189 |
+
|
| 190 |
+
|
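# Illustrative shape check for BertEmbeddings (a sketch, not part of the upstream file;
# defaults taken from BertConfig above):
#
#     import torch
#     from transformers import BertConfig
#
#     emb = BertEmbeddings(BertConfig())
#     out = emb(input_ids=torch.randint(0, 30522, (2, 16)))
#     assert out.shape == (2, 16, 768)  # (batch, seq_len, hidden_size)
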
class BertSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = self.query(hidden_states)
        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
            1, 2
        )

        is_updated = False
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_layer = curr_past_key_value.layers[self.layer_idx].keys
            value_layer = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_layer = self.key(current_states)
            key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
                1, 2
            )
            value_layer = self.value(current_states)
            value_layer = value_layer.view(
                batch_size, -1, self.num_attention_heads, self.attention_head_size
            ).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_layer, value_layer = curr_past_key_value.update(
                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if past_key_values is not None:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, attention_probs


class BertSdpaSelfAttention(BertSelfAttention):
    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
        self.dropout_prob = config.attention_probs_dropout_prob

    # Adapted from BertSelfAttention
    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
            logger.warning_once(
                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the manual attention implementation, but specifying the manual implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                past_key_values,
                output_attentions,
                cache_position,
            )

        bsz, tgt_len, _ = hidden_states.size()

        query_layer = (
            self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        )

        is_updated = False
        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_layer = curr_past_key_value.layers[self.layer_idx].keys
            value_layer = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_layer = (
                self.key(current_states)
                .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
                .transpose(1, 2)
            )
            value_layer = (
                self.value(current_states)
                .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
                .transpose(1, 2)
            )

            if past_key_values is not None:
                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_layer, value_layer = curr_past_key_value.update(
                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
        # a causal mask in case tgt_len == 1.
        is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        return attn_output, None


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


BERT_SELF_ATTENTION_CLASSES = {
    "eager": BertSelfAttention,
    "sdpa": BertSdpaSelfAttention,
}

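# The dict above is keyed by config._attn_implementation; a hedged sketch of how a
# caller selects a backend at load time (attn_implementation is a real from_pretrained
# argument; the model name is illustrative):
#
#     from transformers import BertModel
#
#     model = BertModel.from_pretrained("google-bert/bert-base-uncased", attn_implementation="eager")
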
class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config,
            position_embedding_type=position_embedding_type,
            layer_idx=layer_idx,
        )
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

| 601 |
+
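# --- Editor's sketch (illustrative, not part of the upstream file). ---
# `apply_chunking_to_forward` above trades graph width for memory: with
# `config.chunk_size_feed_forward > 0`, the feed-forward block is applied to
# slices of the sequence dimension and the results are concatenated, which is
# mathematically identical since the FFN acts per position. The helper name
# and toy function below are illustrative assumptions, not upstream API.
def _sketch_feed_forward_chunking():
    def _toy_ffn(x):  # stands in for BertLayer.feed_forward_chunk
        return x * 2.0

    hidden = torch.randn(2, 8, 4)  # (batch, seq_len, hidden)
    chunked = apply_chunking_to_forward(_toy_ffn, 2, 1, hidden)  # chunk_size=2 along dim 1
    unchunked = _toy_ffn(hidden)
    assert torch.allclose(chunked, unchunked)  # same result, smaller peak activations

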
class BertEncoder(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and self.config.is_decoder and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


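# --- Editor's sketch (illustrative, not part of the upstream file). ---
# BertEncoder.forward above upgrades a legacy tuple-of-tuples KV cache to an
# `EncoderDecoderCache`, as its deprecation warning recommends. A minimal
# sketch of that migration; the tensor shapes and layer count are assumptions
# for illustration only.
def _sketch_legacy_cache_migration():
    # legacy format: one tuple per layer of (self_k, self_v, cross_k, cross_v),
    # each of shape (batch, num_heads, seq_len, head_dim)
    legacy = tuple((torch.zeros(1, 12, 5, 64),) * 4 for _ in range(12))
    cache = EncoderDecoderCache.from_legacy_cache(legacy)
    return cache.get_seq_length()  # number of already-cached positions (5 here)

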
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


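# --- Editor's sketch (illustrative, not part of the upstream file). ---
# The pooler is just tanh(W @ h_[CLS] + b): a linear projection of the first
# token's hidden state. A quick equivalence check; the config and tensor sizes
# here are arbitrary illustrative assumptions.
def _sketch_pooler_equivalence():
    config = BertConfig(hidden_size=32)
    pooler = BertPooler(config)
    hidden_states = torch.randn(2, 7, 32)  # (batch, seq_len, hidden)
    manual = torch.tanh(pooler.dense(hidden_states[:, 0]))
    assert torch.allclose(pooler(hidden_states), manual)

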
class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


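# --- Editor's sketch (illustrative, not part of the upstream file). ---
# The `decoder.bias = self.bias` link above keeps the output bias in sync with
# `resize_token_embeddings`, while `post_init()` in the concrete models ties
# the decoder weight to the input embedding matrix. A quick check of that
# sharing with an assumed tiny config (sizes chosen arbitrarily):
def _sketch_tied_lm_head():
    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64
    )
    model = BertForMaskedLM(config)  # defined further down in this file
    # with the default `tie_word_embeddings=True`, both point at the same storage
    assert model.get_input_embeddings().weight.data_ptr() == model.get_output_embeddings().weight.data_ptr()

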
@auto_docstring
class BertPreTrainedModel(PreTrainedModel):
    config: BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, BertLMPredictionHead):
            module.bias.data.zero_()


@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`BertForPreTraining`].
    """
)
class BertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    seq_relationship_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None


@auto_docstring(
    custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
)
class BertModel(BertPreTrainedModel):
    _no_split_modules = ["BertEmbeddings", "BertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = (
                past_key_values[0][0].shape[-2]
                if not isinstance(past_key_values, Cache)
                else past_key_values.get_seq_length()
            )

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        use_sdpa_attention_masks = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        # Expand the attention mask
        if use_sdpa_attention_masks and attention_mask.dim() == 2:
            # Expand the attention mask for SDPA.
            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
            if self.config.is_decoder:
                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask,
                    input_shape,
                    embedding_output,
                    past_key_values_length,
                )
            else:
                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
        else:
            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
            # ourselves in which case we just need to make it broadcastable to all heads.
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
                # Expand the attention mask for SDPA.
                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


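# --- Editor's sketch (illustrative, not part of the upstream file). ---
# Minimal encoder-only usage of BertModel: `last_hidden_state` is the per-token
# representation, `pooler_output` the tanh-projected [CLS] vector used by the
# classification heads further down. Checkpoint name as in this file's own
# docstring examples.
def _sketch_bert_model_usage():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = BertModel.from_pretrained("google-bert/bert-base-uncased")
    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)
    print(outputs.pooler_output.shape)  # (1, 768)

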
@auto_docstring(
    custom_intro="""
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """
)
class BertForPreTraining(BertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BertForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
            pair (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    Bert Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`.")

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **loss_kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


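# --- Editor's sketch (illustrative, not part of the upstream file). ---
# BertLMHeadModel only generates sensibly when the config marks it as a
# decoder (see the warning in __init__). A minimal sketch; note the plain BERT
# checkpoint was never trained for causal LM, so the continuation is for
# demonstration only.
def _sketch_bert_clm_generation():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = BertLMHeadModel.from_pretrained("google-bert/bert-base-uncased", is_decoder=True)
    inputs = tokenizer("Hello, my dog is", return_tensors="pt")
    generated = model.generate(**inputs, max_new_tokens=5)  # GenerationMixin entry point
    print(tokenizer.decode(generated[0]))

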
@auto_docstring
class BertForMaskedLM(BertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @classmethod
    def can_generate(cls) -> bool:
        """
        Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
        `prepare_inputs_for_generation` method.
        """
        return False


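# --- Editor's sketch (illustrative, not part of the upstream file). ---
# Recovering a masked token with BertForMaskedLM: the logits at the [MASK]
# position form a distribution over the vocabulary.
def _sketch_fill_mask():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("google-bert/bert-base-uncased")
    inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
    predicted_id = logits[0, mask_index].argmax().item()
    print(tokenizer.decode([predicted_id]))  # expected: "paris"

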
@auto_docstring(
    custom_intro="""
    Bert Model with a `next sentence prediction (classification)` head on top.
    """
)
class BertForNextSentencePrediction(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```
        """

        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (mean-squared loss); if
            `config.num_labels > 1` a classification loss is computed (cross-entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


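# --- Editor's sketch (illustrative, not part of the upstream file). ---
# The loss dispatch above infers `config.problem_type` on first use: one label
# means regression (MSELoss); several labels with integer targets mean
# single-label classification (CrossEntropyLoss); float targets mean
# multi-label classification (BCEWithLogitsLoss). Sketch with an assumed tiny
# untrained config:
def _sketch_problem_type_dispatch():
    config = BertConfig(
        hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64, num_labels=3
    )
    model = BertForSequenceClassification(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 6))
    out = model(input_ids, labels=torch.tensor([0, 2]))  # long dtype -> cross-entropy
    assert model.config.problem_type == "single_label_classification"
    return out.loss

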
@auto_docstring
class BertForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


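# --- Editor's sketch (illustrative, not part of the upstream file). ---
# BertForMultipleChoice flattens (batch, num_choices, seq_len) inputs to
# (batch * num_choices, seq_len), scores each choice with the single-unit
# linear head, then reshapes back to (batch, num_choices) so cross-entropy
# runs over the choices. Shape walk-through with illustrative sizes:
def _sketch_multiple_choice_shapes():
    batch, num_choices, seq_len = 2, 4, 6
    input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
    flat = input_ids.view(-1, input_ids.size(-1))  # (8, 6), as in forward() above
    logits = torch.randn(flat.size(0), 1)  # what self.classifier produces
    reshaped = logits.view(-1, num_choices)  # (2, 4): one score per choice
    assert reshaped.shape == (batch, num_choices)

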
@auto_docstring
class BertForTokenClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


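# --- Editor's sketch (illustrative, not part of the upstream file). ---
# Token classification emits one logit vector per token; labels set to -100
# would be masked out of the loss by CrossEntropyLoss's default ignore_index.
# Shape check with an assumed tiny untrained config:
def _sketch_token_classification_shapes():
    config = BertConfig(
        hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64, num_labels=5
    )
    model = BertForTokenClassification(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 6))
    logits = model(input_ids).logits
    assert logits.shape == (2, 6, config.num_labels)  # one prediction per token

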
@auto_docstring
class BertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, splitting adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
|
| 1789 |
+
"BertForMaskedLM",
|
| 1790 |
+
"BertForMultipleChoice",
|
| 1791 |
+
"BertForNextSentencePrediction",
|
| 1792 |
+
"BertForPreTraining",
|
| 1793 |
+
"BertForQuestionAnswering",
|
| 1794 |
+
"BertForSequenceClassification",
|
| 1795 |
+
"BertForTokenClassification",
|
| 1796 |
+
"BertLayer",
|
| 1797 |
+
"BertLMHeadModel",
|
| 1798 |
+
"BertModel",
|
| 1799 |
+
"BertPreTrainedModel",
|
| 1800 |
+
"load_tf_weights_in_bert",
|
| 1801 |
+
]
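As a quick orientation on the span-extraction convention implemented above (start/end logits from a single linear head, label positions clamped to `sequence_length`, which then doubles as `ignore_index` so out-of-window answers contribute no loss), here is a minimal usage sketch; the checkpoint name is only an assumption for illustration and is not pinned down by this diff.

# Hedged usage sketch for BertForQuestionAnswering (checkpoint name assumed).
import torch
from transformers import AutoTokenizer, BertForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("google-bert/bert-base-uncased")

inputs = tokenizer("Who made BERT?", "BERT was made at Google.", return_tensors="pt")
# Token indices of the answer span; out-of-range positions are clamped and ignored in the loss.
outputs = model(**inputs, start_positions=torch.tensor([7]), end_positions=torch.tensor([8]))
print(outputs.loss, outputs.start_logits.shape)  # scalar loss and (batch_size, sequence_length)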
URSA/.venv_ursa/lib/python3.12/site-packages/transformers/models/bert/modeling_flax_bert.py
ADDED
@@ -0,0 +1,1727 @@
# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxNextSentencePredictorOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_bert import BertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

remat = nn_partitioning.remat

@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`BertForPreTraining`].

    Args:
        prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    prediction_logits: jnp.ndarray = None
    seq_relationship_logits: jnp.ndarray = None
    hidden_states: Optional[tuple[jnp.ndarray]] = None
    attentions: Optional[tuple[jnp.ndarray]] = None

BERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].

"""

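# --- Illustrative aside (not part of the file being diffed): computation dtype vs. parameter dtype.
# The `dtype` paragraph in BERT_START_DOCSTRING above draws a distinction that is easy to misread,
# so this hedged sketch spells it out; the checkpoint name is taken from _CHECKPOINT_FOR_DOC.
import jax.numpy as jnp
from transformers import FlaxBertModel

# All matmuls and softmaxes run in bfloat16, but model.params remain float32.
model = FlaxBertModel.from_pretrained("google-bert/bert-base-uncased", dtype=jnp.bfloat16)

# Casting the stored parameters themselves is a separate, explicit step.
model.params = model.to_bf16(model.params)
# --- End of aside. ---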
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

"""

class FlaxBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states

class FlaxBertSelfAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to
            # those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.query(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.key(key_value_states)
            value_states = self.value(key_value_states)
        else:
            # self_attention
            key_states = self.key(hidden_states)
            value_states = self.value(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache / prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

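# --- Illustrative aside (not part of the file being diffed): the cache write in
# _concatenate_to_cache reduces to one lax.dynamic_update_slice per decoding step plus an
# index bump. A standalone toy sketch of that mechanic, with illustrative shapes and no
# Flax variable machinery:
import jax.numpy as jnp
from jax import lax

max_length, num_heads, head_dim = 8, 2, 4
cached_key = jnp.zeros((1, max_length, num_heads, head_dim))  # (batch, max_len, heads, dim)
cache_index = 3                                               # positions 0..2 already filled

new_key = jnp.ones((1, 1, num_heads, head_dim))               # one freshly decoded position
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cache_index, 0, 0))
cache_index += new_key.shape[1]

# Queries may only attend to filled slots, mirroring the pad_mask construction above.
pad_mask = jnp.arange(max_length) < cache_index
print(cached_key[0, :, 0, 0], pad_mask)
# --- End of aside. ---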
class FlaxBertSelfOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class FlaxBertAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs

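# --- Illustrative aside (not part of the file being diffed): the shape contract described in
# the comment at the top of __call__ above, made concrete with toy shapes:
import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1, 0]])                 # (batch=1, kv_length=4)
expanded = jnp.expand_dims(attention_mask, axis=(-3, -2))  # (1, 1, 1, 4)
attn_weights = jnp.zeros((1, 12, 4, 4))                    # (batch, num_heads, q_len, kv_len)
print((attn_weights + expanded).shape)                     # broadcasts to (1, 12, 4, 4)
# --- End of aside. ---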
class FlaxBertIntermediate(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxBertOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states

class FlaxBertLayer(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBertOutput(self.config, dtype=self.dtype)
        if self.config.add_cross_attention:
            self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        return outputs

class FlaxBertLayerCollection(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
            self.layers = [
                FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            self.layers = [
                FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.shape[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.shape[0]}."
                )

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

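# --- Illustrative aside (not part of the file being diffed): setup() above swaps each layer
# for a remat-wrapped variant when gradient_checkpointing is on. As a hedged illustration of
# what rematerialization buys (recompute activations in the backward pass instead of storing
# them), here is the same idea with plain jax.checkpoint on a toy function; this is not the
# module wrapper used above, just the underlying mechanism.
import jax
import jax.numpy as jnp

def f(x):
    for _ in range(4):          # stand-in for a stack of transformer layers
        x = jnp.tanh(x @ x.T) @ x
    return x.sum()

grad_plain = jax.grad(f)
grad_remat = jax.grad(jax.checkpoint(f))  # same gradients, smaller peak activation memory

x = jnp.ones((8, 8))
print(jnp.allclose(grad_plain(x), grad_remat(x)))  # True
# --- End of aside. ---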
class FlaxBertEncoder(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.layer = FlaxBertLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

class FlaxBertPooler(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]
        cls_hidden_state = self.dense(cls_hidden_state)
        return nn.tanh(cls_hidden_state)

class FlaxBertPredictionHeadTransform(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.LayerNorm(hidden_states)

class FlaxBertLMPredictionHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.transform(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        bias = jnp.asarray(self.bias, self.dtype)
        hidden_states += bias
        return hidden_states

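# --- Illustrative aside (not part of the file being diffed): when shared_embedding is passed,
# the head above reuses the word-embedding matrix (transposed) as the decoder kernel instead of
# a separately learned output matrix. A toy sketch of that substitution, with made-up sizes:
import jax.numpy as jnp

vocab_size, hidden_size = 10, 4
embedding = jnp.full((vocab_size, hidden_size), 0.1)  # stands in for the "embedding" param
hidden_states = jnp.ones((1, 3, hidden_size))         # (batch, seq_len, hidden)

logits = hidden_states @ embedding.T                  # (1, 3, vocab_size), tied weights
print(logits.shape)
# --- End of aside. ---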
class FlaxBertOnlyMLMHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
        return hidden_states


class FlaxBertOnlyNSPHead(nn.Module):
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output):
        return self.seq_relationship(pooled_output)

class FlaxBertPreTrainingHeads(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, hidden_states, pooled_output, shared_embedding=None):
        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BertConfig,
        input_shape: tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        module = self.module_class(
            config=config,
            dtype=dtype,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def enable_gradient_checkpointing(self):
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: Optional[dict] = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: Optional[dict] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        if self.config.add_cross_attention:
            # If past_key_values are passed, the cache is already initialized; the private flag `init_cache` has to
            # be passed down to ensure the cache is used. The cache must also be marked as mutable so that it can be
            # changed by the FlaxBertAttention module.
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # add updated cache to model output
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        else:
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
            )

        return outputs

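# --- Illustrative aside (not part of the file being diffed): init_weights above returns a
# nested parameter pytree. A hedged sketch of inspecting it via flatten_dict; the tiny config
# is an assumption chosen only to keep random initialization fast.
from flax.traverse_util import flatten_dict
from transformers import BertConfig, FlaxBertModel

model = FlaxBertModel(BertConfig(num_hidden_layers=2))  # random init goes through init_weights
flat = flatten_dict(model.params)
for path in list(flat)[:3]:
    print("/".join(path), flat[path].shape)  # e.g. embeddings/word_embeddings/embedding ...
# --- End of aside. ---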
class FlaxBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

@add_start_docstrings(
|
| 1027 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1028 |
+
BERT_START_DOCSTRING,
|
| 1029 |
+
)
|
| 1030 |
+
class FlaxBertModel(FlaxBertPreTrainedModel):
|
| 1031 |
+
module_class = FlaxBertModule
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
append_call_sample_docstring(FlaxBertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
|
| 1035 |
+
|
| 1036 |
+
|
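

# Usage sketch for the bare encoder, assuming the public "google-bert/bert-base-uncased"
# checkpoint (doctest style, matching the docstring examples in this file):
#
# >>> from transformers import AutoTokenizer, FlaxBertModel
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-uncased")
# >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
# >>> outputs = model(**inputs)
# >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
# (1, 8, 768)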


class FlaxBertForPreTrainingModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        hidden_states = outputs[0]
        pooled_output = outputs[1]

        prediction_scores, seq_relationship_score = self.cls(
            hidden_states, pooled_output, shared_embedding=shared_embedding
        )

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return FlaxBertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    module_class = FlaxBertForPreTrainingModule


FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.seq_relationship_logits
    ```
"""

overwrite_call_docstring(
    FlaxBertForPreTraining,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForMaskedLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMaskedLMModule


append_call_sample_docstring(FlaxBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
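

# Usage sketch for mask filling, assuming the public "google-bert/bert-base-uncased"
# checkpoint (doctest style, matching the docstring examples in this file):
#
# >>> from transformers import AutoTokenizer, FlaxBertForMaskedLM
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertForMaskedLM.from_pretrained("google-bert/bert-base-uncased")
# >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
# >>> logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)
# >>> mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).argmax(axis=-1)
# >>> tokenizer.decode(logits[0, mask_index[0]].argmax(axis=-1))
# 'paris'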


class FlaxBertForNextSentencePredictionModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        if not return_dict:
            return (seq_relationship_scores,) + outputs[2:]

        return FlaxNextSentencePredictorOutput(
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    module_class = FlaxBertForNextSentencePredictionModule


FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    >>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax")

    >>> outputs = model(**encoding)
    >>> logits = outputs.logits
    >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
    ```
"""


overwrite_call_docstring(
    FlaxBertForNextSentencePrediction,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForSequenceClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxBertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
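

# Usage sketch for sequence classification. Loading a plain BERT checkpoint
# leaves the classification head randomly initialized, so the prediction below
# is a shape illustration only (num_labels=2 is an assumption for the example):
#
# >>> from transformers import AutoTokenizer, FlaxBertForSequenceClassification
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertForSequenceClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=2)
# >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
# >>> logits = model(**inputs).logits  # (batch_size, num_labels)
# >>> predicted_class = logits.argmax(axis=-1)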


class FlaxBertForMultipleChoiceModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMultipleChoiceModule


overwrite_call_docstring(
    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxBertForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)
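

# Input-shaping sketch: the module expects (batch_size, num_choices, sequence_length)
# and flattens the choices into the batch dimension before the encoder. Assumes the
# public "google-bert/bert-base-uncased" checkpoint; the prompt/choice texts are
# made up for illustration:
#
# >>> from transformers import AutoTokenizer, FlaxBertForMultipleChoice
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertForMultipleChoice.from_pretrained("google-bert/bert-base-uncased")
# >>> prompt = "A baguette is a long, thin loaf of French bread."
# >>> choice0 = "It is eaten with a fork and a knife."
# >>> choice1 = "It is eaten while held in the hand."
# >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="np", padding=True)
# >>> inputs = {k: v[None, :] for k, v in encoding.items()}  # add the batch dimension
# >>> logits = model(**inputs).logits  # (1, num_choices)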


class FlaxBertForTokenClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForTokenClassificationModule


append_call_sample_docstring(
    FlaxBertForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)
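

# Usage sketch for per-token labeling (NER-style). With a plain BERT checkpoint
# the head is randomly initialized, so the output is a shape illustration only
# (num_labels=9 is an assumption, matching a typical CoNLL-style tag set):
#
# >>> from transformers import AutoTokenizer, FlaxBertForTokenClassification
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertForTokenClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=9)
# >>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="np")
# >>> logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)
# >>> predicted_token_classes = logits.argmax(axis=-1)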


class FlaxBertForQuestionAnsweringModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    module_class = FlaxBertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxBertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
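

# Usage sketch for extractive QA span selection. With an un-fine-tuned BERT
# checkpoint the selected span is meaningless; the example only shows how the
# start/end logits are consumed:
#
# >>> from transformers import AutoTokenizer, FlaxBertForQuestionAnswering
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertForQuestionAnswering.from_pretrained("google-bert/bert-base-uncased")
# >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
# >>> inputs = tokenizer(question, text, return_tensors="np")
# >>> outputs = model(**inputs)
# >>> start = outputs.start_logits.argmax(axis=-1)[0]
# >>> end = outputs.end_logits.argmax(axis=-1)[0]
# >>> answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])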


class FlaxBertForCausalLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
    autoregressive tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


append_call_sample_docstring(
    FlaxBertForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
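

# Generation sketch: `prepare_inputs_for_generation` pre-allocates the key/value
# cache at `max_length` and builds a single static attention mask so the compiled
# decoding step has fixed shapes. A minimal greedy-generation call, assuming a
# BERT checkpoint loaded as a decoder (the LM head is untrained for this task,
# so the continuation is illustrative only):
#
# >>> from transformers import AutoTokenizer, FlaxBertForCausalLM
# >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# >>> model = FlaxBertForCausalLM.from_pretrained("google-bert/bert-base-uncased", is_decoder=True)
# >>> inputs = tokenizer("Hello, my dog is", return_tensors="np")
# >>> generated = model.generate(inputs["input_ids"], max_length=16)
# >>> tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)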


__all__ = [
    "FlaxBertForCausalLM",
    "FlaxBertForMaskedLM",
    "FlaxBertForMultipleChoice",
    "FlaxBertForNextSentencePrediction",
    "FlaxBertForPreTraining",
    "FlaxBertForQuestionAnswering",
    "FlaxBertForSequenceClassification",
    "FlaxBertForTokenClassification",
    "FlaxBertModel",
    "FlaxBertPreTrainedModel",
]