Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py +27 -0
- .venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py +289 -0
- .venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py +197 -0
- .venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py +229 -0
- .venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py +36 -0
- .venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py +515 -0
- .venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py +956 -0
- .venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py +27 -0
- .venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py +452 -0
- .venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py +381 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py +30 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py +142 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py +36 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py +323 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py +551 -0
- .venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py +669 -0
- .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py +27 -0
- .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py +963 -0
- .venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py +26 -0
- .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py +299 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py +28 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py +247 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py +144 -0
- .venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py +27 -0
- .venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py +182 -0
- .venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py +1299 -0
- .venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py +31 -0
- .venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .tokenization_barthez import *
+    from .tokenization_barthez_fast import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (777 Bytes).
.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc
ADDED
Binary file (14.9 kB).
.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc
ADDED
Binary file (8.97 kB).
.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py
ADDED
@@ -0,0 +1,289 @@
+# coding=utf-8
+# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+"""Tokenization classes for the BARThez model."""
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+# TODO this class is useless. This is the most standard sentencpiece model. Let's find which one is closest and nuke this.
+
+
+class BarthezTokenizer(PreTrainedTokenizer):
+    """
+    Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Attributes:
+        sp_model (`SentencePieceProcessor`):
+            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        # Mask token behave like a normal word, i.e. include the space before it. Will have normalized=False by default this way
+        mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.vocab_file = vocab_file
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        super().__init__(
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A BARThez sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    @property
+    def vocab_size(self):
+        return len(self.sp_model)
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text: str) -> List[str]:
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.sp_model.PieceToId(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
+
+    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        prev_is_special = False
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self.all_special_tokens:
+                if not prev_is_special:
+                    out_string += " "
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                prev_is_special = True
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+                prev_is_special = False
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string.strip()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.vocab_file)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["BarthezTokenizer"]
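As a rough, hedged illustration of the special-token layout documented in the hunk above (`<s> X </s>` for a single sequence, `<s> A </s></s> B </s>` for a pair), a minimal sketch follows. The vocabulary path is a placeholder, not a file included in this commit, and `transformers` plus `sentencepiece` are assumed to be installed.

```python
# Sketch only: "sentencepiece.bpe.model" is a placeholder path to a BARThez
# SentencePiece vocabulary file; point it at the file from a real checkpoint.
from transformers import BarthezTokenizer

tokenizer = BarthezTokenizer(vocab_file="sentencepiece.bpe.model")

# Pair of sequences -> "<s> A </s></s> B </s>"
ids = tokenizer.build_inputs_with_special_tokens([10, 11], [20, 21])
print(ids)  # [cls_id, 10, 11, sep_id, sep_id, 20, 21, sep_id]

# 1 marks a special token, 0 marks a sequence token
print(tokenizer.get_special_tokens_mask([10, 11], [20, 21]))  # [1, 0, 0, 1, 1, 0, 0, 1]
```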
.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py
ADDED
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+"""Tokenization classes for the BARThez model."""
+
+import os
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+from ...tokenization_utils import AddedToken
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
+
+
+if is_sentencepiece_available():
+    from .tokenization_barthez import BarthezTokenizer
+else:
+    BarthezTokenizer = None
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
+
+
+SPIECE_UNDERLINE = "▁"
+
+
+class BarthezTokenizerFast(PreTrainedTokenizerFast):
+    """
+    Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a "fast" BARThez tokenizer. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).
+
+    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
+            contains the vocabulary necessary to instantiate a tokenizer.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
+            Additional special tokens used by the tokenizer.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+    slow_tokenizer_class = BarthezTokenizer
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        **kwargs,
+    ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        super().__init__(
+            vocab_file,
+            tokenizer_file=tokenizer_file,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A BARThez sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+
+__all__ = ["BarthezTokenizerFast"]
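For the fast variant added above, a hedged construction sketch; both file names are placeholders rather than files in this commit (a real BARThez checkpoint provides the actual `tokenizer.json` and SentencePiece model).

```python
# Placeholder file names; substitute the files shipped with a real BARThez checkpoint.
from transformers import BarthezTokenizerFast

tok = BarthezTokenizerFast(
    vocab_file="sentencepiece.bpe.model",
    tokenizer_file="tokenizer.json",
)

# save_vocabulary() copies the original SentencePiece file, so it only works
# when that file is available, as reported by can_save_slow_tokenizer.
if tok.can_save_slow_tokenizer:
    tok.save_vocabulary("exported_tokenizer")  # directory is assumed to exist
```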
.venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py
ADDED
@@ -0,0 +1,229 @@
+# coding=utf-8
+# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BEiT model configuration"""
+
+import warnings
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+class BeitConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the BEiT
+    [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 8192):
+            Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
+            pre-training.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        use_mask_token (`bool`, *optional*, defaults to `False`):
+            Whether to use a mask token for masked image modeling.
+        use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to use BERT-style absolute position embeddings.
+        use_relative_position_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use T5-style relative position embeddings in the self-attention layers.
+        use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
+        layer_scale_init_value (`float`, *optional*, defaults to 0.1):
+            Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate per sample (when applied in the main path of residual layers).
+        use_mean_pooling (`bool`, *optional*, defaults to `True`):
+            Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
+            CLS token, before applying the classification head.
+        pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
+            Pooling scales used in Pooling Pyramid Module applied on the last feature map.
+        use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+            Whether to use an auxiliary head during training.
+        auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+            Weight of the cross-entropy loss of the auxiliary head.
+        auxiliary_channels (`int`, *optional*, defaults to 256):
+            Number of channels to use in the auxiliary head.
+        auxiliary_num_convs (`int`, *optional*, defaults to 1):
+            Number of convolutional layers to use in the auxiliary head.
+        auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
+            Whether to concatenate the output of the auxiliary head with the input before the classification layer.
+        semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+            The index that is ignored by the loss function of the semantic segmentation model.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        add_fpn (`bool`, *optional*, defaults to `False`):
+            Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`].
+        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+            seq_len, hidden_size)`. Only relevant for [`BeitBackbone`].
+
+    Example:
+
+    ```python
+    >>> from transformers import BeitConfig, BeitModel
+
+    >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration
+    >>> configuration = BeitConfig()
+
+    >>> # Initializing a model (with random weights) from the beit-base-patch16-224-pt22k style configuration
+    >>> model = BeitModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "beit"
+
+    def __init__(
+        self,
+        vocab_size=8192,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        image_size=224,
+        patch_size=16,
+        num_channels=3,
+        use_mask_token=False,
+        use_absolute_position_embeddings=False,
+        use_relative_position_bias=False,
+        use_shared_relative_position_bias=False,
+        layer_scale_init_value=0.1,
+        drop_path_rate=0.1,
+        use_mean_pooling=True,
+        pool_scales=[1, 2, 3, 6],
+        use_auxiliary_head=True,
+        auxiliary_loss_weight=0.4,
+        auxiliary_channels=256,
+        auxiliary_num_convs=1,
+        auxiliary_concat_input=False,
+        semantic_loss_ignore_index=255,
+        out_features=None,
+        out_indices=None,
+        add_fpn=False,
+        reshape_hidden_states=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.use_mask_token = use_mask_token
+        self.use_absolute_position_embeddings = use_absolute_position_embeddings
+        self.use_relative_position_bias = use_relative_position_bias
+        self.use_shared_relative_position_bias = use_shared_relative_position_bias
+        self.layer_scale_init_value = layer_scale_init_value
+        self.drop_path_rate = drop_path_rate
+        self.use_mean_pooling = use_mean_pooling
+        # decode head attributes (semantic segmentation)
+        self.pool_scales = pool_scales
+        # auxiliary head attributes (semantic segmentation)
+        self.use_auxiliary_head = use_auxiliary_head
+        self.auxiliary_loss_weight = auxiliary_loss_weight
+        self.auxiliary_channels = auxiliary_channels
+        self.auxiliary_num_convs = auxiliary_num_convs
+        self.auxiliary_concat_input = auxiliary_concat_input
+        self.semantic_loss_ignore_index = semantic_loss_ignore_index
+
+        # handle backwards compatibility
+        if "segmentation_indices" in kwargs:
+            warnings.warn(
+                "The `segmentation_indices` argument is deprecated and will be removed in a future version, use `out_indices` instead.",
+                FutureWarning,
+            )
+            out_indices = kwargs.pop("segmentation_indices")
+
+        # backbone attributes
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+        self.add_fpn = add_fpn
+        self.reshape_hidden_states = reshape_hidden_states
+
+
+# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
+class BeitOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+
+__all__ = ["BeitConfig", "BeitOnnxConfig"]
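Beyond the `Example` block already embedded in the docstring above, a short sketch of the backbone-related arguments (`out_features`, `out_indices`, `stage_names`); the chosen stages are arbitrary examples and the printed values are only what the alignment helper is expected to produce.

```python
from transformers import BeitConfig

# With the default 12 hidden layers, stage_names is ["stem", "stage1", ..., "stage12"].
config = BeitConfig(out_features=["stage3", "stage6", "stage9", "stage12"])

print(config.out_features)  # ['stage3', 'stage6', 'stage9', 'stage12']
print(config.out_indices)   # indices aligned to stage_names, e.g. [3, 6, 9, 12]
```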
.venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py
ADDED
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for BEiT."""
+
+import warnings
+
+from ...utils import logging
+from .image_processing_beit import BeitImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+class BeitFeatureExtractor(BeitImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The class BeitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+            " use BeitImageProcessor instead.",
+            FutureWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["BeitFeatureExtractor"]
.venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for Beit."""
|
| 16 |
+
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from ...image_transforms import resize, to_channel_dimension_format
|
| 23 |
+
from ...image_utils import (
|
| 24 |
+
IMAGENET_STANDARD_MEAN,
|
| 25 |
+
IMAGENET_STANDARD_STD,
|
| 26 |
+
ChannelDimension,
|
| 27 |
+
ImageInput,
|
| 28 |
+
PILImageResampling,
|
| 29 |
+
infer_channel_dimension_format,
|
| 30 |
+
is_scaled_image,
|
| 31 |
+
make_list_of_images,
|
| 32 |
+
to_numpy_array,
|
| 33 |
+
valid_images,
|
| 34 |
+
validate_preprocess_arguments,
|
| 35 |
+
)
|
| 36 |
+
from ...utils import (
|
| 37 |
+
TensorType,
|
| 38 |
+
filter_out_non_signature_kwargs,
|
| 39 |
+
is_torch_available,
|
| 40 |
+
is_torch_tensor,
|
| 41 |
+
is_vision_available,
|
| 42 |
+
logging,
|
| 43 |
+
)
|
| 44 |
+
from ...utils.deprecation import deprecate_kwarg
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_vision_available():
|
| 48 |
+
import PIL
|
| 49 |
+
|
| 50 |
+
if is_torch_available():
|
| 51 |
+
import torch
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
logger = logging.get_logger(__name__)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BeitImageProcessor(BaseImageProcessor):
|
| 58 |
+
r"""
|
| 59 |
+
Constructs a BEiT image processor.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 63 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
| 64 |
+
`do_resize` parameter in the `preprocess` method.
|
| 65 |
+
size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
|
| 66 |
+
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
| 67 |
+
method.
|
| 68 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
| 69 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 70 |
+
`preprocess` method.
|
| 71 |
+
do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 72 |
+
Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
|
| 73 |
+
is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
|
| 74 |
+
`preprocess` method.
|
| 75 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 76 |
+
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
|
| 77 |
+
Can be overridden by the `crop_size` parameter in the `preprocess` method.
|
| 78 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 79 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 80 |
+
`preprocess` method.
|
| 81 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 82 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 83 |
+
parameter in the `preprocess` method.
|
| 84 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 85 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 86 |
+
method.
|
| 87 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
| 88 |
+
The mean to use if normalizing the image. This is a float or list of floats of length of the number of
|
| 89 |
+
channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 90 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
| 91 |
+
The standard deviation to use if normalizing the image. This is a float or list of floats of length of the
|
| 92 |
+
number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 93 |
+
do_reduce_labels (`bool`, *optional*, defaults to `False`):
|
| 94 |
+
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
|
| 95 |
+
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
|
| 96 |
+
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
|
| 97 |
+
`preprocess` method.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
model_input_names = ["pixel_values"]
|
| 101 |
+
|
| 102 |
+
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
| 103 |
+
@filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
do_resize: bool = True,
|
| 107 |
+
size: Dict[str, int] = None,
|
| 108 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 109 |
+
do_center_crop: bool = True,
|
| 110 |
+
crop_size: Dict[str, int] = None,
|
| 111 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 112 |
+
do_rescale: bool = True,
|
| 113 |
+
do_normalize: bool = True,
|
| 114 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 115 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 116 |
+
do_reduce_labels: bool = False,
|
| 117 |
+
**kwargs,
|
| 118 |
+
) -> None:
|
| 119 |
+
super().__init__(**kwargs)
|
| 120 |
+
size = size if size is not None else {"height": 256, "width": 256}
|
| 121 |
+
size = get_size_dict(size)
|
| 122 |
+
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
| 123 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
| 124 |
+
self.do_resize = do_resize
|
| 125 |
+
self.size = size
|
| 126 |
+
self.resample = resample
|
| 127 |
+
self.do_center_crop = do_center_crop
|
| 128 |
+
self.crop_size = crop_size
|
| 129 |
+
self.do_rescale = do_rescale
|
| 130 |
+
self.rescale_factor = rescale_factor
|
| 131 |
+
self.do_normalize = do_normalize
|
| 132 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 133 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 134 |
+
self.do_reduce_labels = do_reduce_labels
|
| 135 |
+
|
| 136 |
+
@classmethod
|
| 137 |
+
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
|
| 138 |
+
"""
|
| 139 |
+
Overrides the `from_dict` method from the base class to preserve support for the deprecated `reduce_labels` key in old configs.
|
| 140 |
+
"""
|
| 141 |
+
image_processor_dict = image_processor_dict.copy()
|
| 142 |
+
if "reduce_labels" in image_processor_dict:
|
| 143 |
+
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
|
| 144 |
+
return super().from_dict(image_processor_dict, **kwargs)
|
| 145 |
+
|
| 146 |
+
def resize(
|
| 147 |
+
self,
|
| 148 |
+
image: np.ndarray,
|
| 149 |
+
size: Dict[str, int],
|
| 150 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
| 151 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 152 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 153 |
+
**kwargs,
|
| 154 |
+
) -> np.ndarray:
|
| 155 |
+
"""
|
| 156 |
+
Resize an image to (size["height"], size["width"]).
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
image (`np.ndarray`):
|
| 160 |
+
Image to resize.
|
| 161 |
+
size (`Dict[str, int]`):
|
| 162 |
+
Size of the output image.
|
| 163 |
+
resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
|
| 164 |
+
Resampling filter to use when resizing the image.
|
| 165 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 166 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 167 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 168 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 169 |
+
"""
|
| 170 |
+
size = get_size_dict(size, default_to_square=True, param_name="size")
|
| 171 |
+
if "height" not in size or "width" not in size:
|
| 172 |
+
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
|
| 173 |
+
return resize(
|
| 174 |
+
image,
|
| 175 |
+
size=(size["height"], size["width"]),
|
| 176 |
+
resample=resample,
|
| 177 |
+
data_format=data_format,
|
| 178 |
+
input_data_format=input_data_format,
|
| 179 |
+
**kwargs,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def reduce_label(self, label: ImageInput) -> np.ndarray:
|
| 183 |
+
label = to_numpy_array(label)
|
| 184 |
+
# Avoid using underflow conversion
|
| 185 |
+
label[label == 0] = 255
|
| 186 |
+
label = label - 1
|
| 187 |
+
label[label == 254] = 255
|
| 188 |
+
return label
|
| 189 |
+
|
| 190 |
+
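# Worked example for `reduce_label` above (illustrative comment, not upstream code):
# a segmentation map [0, 1, 2] becomes [255, 0, 1]. The background id 0 is first mapped
# to 255, every value is then shifted down by one, and the `label == 254` step restores
# the pixels that started at 0 (255 - 1 = 254) back to the ignore index 255.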
def _preprocess(
|
| 191 |
+
self,
|
| 192 |
+
image: ImageInput,
|
| 193 |
+
do_reduce_labels: bool = None,
|
| 194 |
+
do_resize: bool = None,
|
| 195 |
+
size: Dict[str, int] = None,
|
| 196 |
+
resample: PILImageResampling = None,
|
| 197 |
+
do_center_crop: bool = None,
|
| 198 |
+
crop_size: Dict[str, int] = None,
|
| 199 |
+
do_rescale: bool = None,
|
| 200 |
+
rescale_factor: float = None,
|
| 201 |
+
do_normalize: bool = None,
|
| 202 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 203 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 204 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 205 |
+
):
|
| 206 |
+
if do_reduce_labels:
|
| 207 |
+
image = self.reduce_label(image)
|
| 208 |
+
|
| 209 |
+
if do_resize:
|
| 210 |
+
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 211 |
+
|
| 212 |
+
if do_center_crop:
|
| 213 |
+
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
| 214 |
+
|
| 215 |
+
if do_rescale:
|
| 216 |
+
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 217 |
+
|
| 218 |
+
if do_normalize:
|
| 219 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 220 |
+
|
| 221 |
+
return image
|
| 222 |
+
|
| 223 |
+
def _preprocess_image(
|
| 224 |
+
self,
|
| 225 |
+
image: ImageInput,
|
| 226 |
+
do_resize: bool = None,
|
| 227 |
+
size: Dict[str, int] = None,
|
| 228 |
+
resample: PILImageResampling = None,
|
| 229 |
+
do_center_crop: bool = None,
|
| 230 |
+
crop_size: Dict[str, int] = None,
|
| 231 |
+
do_rescale: bool = None,
|
| 232 |
+
rescale_factor: float = None,
|
| 233 |
+
do_normalize: bool = None,
|
| 234 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 235 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 236 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 237 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 238 |
+
) -> np.ndarray:
|
| 239 |
+
"""Preprocesses a single image."""
|
| 240 |
+
# All transformations expect numpy arrays.
|
| 241 |
+
image = to_numpy_array(image)
|
| 242 |
+
if do_rescale and is_scaled_image(image):
|
| 243 |
+
logger.warning_once(
|
| 244 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 245 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 246 |
+
)
|
| 247 |
+
if input_data_format is None:
|
| 248 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 249 |
+
image = self._preprocess(
|
| 250 |
+
image,
|
| 251 |
+
do_reduce_labels=False,
|
| 252 |
+
do_resize=do_resize,
|
| 253 |
+
size=size,
|
| 254 |
+
resample=resample,
|
| 255 |
+
do_center_crop=do_center_crop,
|
| 256 |
+
crop_size=crop_size,
|
| 257 |
+
do_rescale=do_rescale,
|
| 258 |
+
rescale_factor=rescale_factor,
|
| 259 |
+
do_normalize=do_normalize,
|
| 260 |
+
image_mean=image_mean,
|
| 261 |
+
image_std=image_std,
|
| 262 |
+
input_data_format=input_data_format,
|
| 263 |
+
)
|
| 264 |
+
if data_format is not None:
|
| 265 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 266 |
+
return image
|
| 267 |
+
|
| 268 |
+
def _preprocess_segmentation_map(
|
| 269 |
+
self,
|
| 270 |
+
segmentation_map: ImageInput,
|
| 271 |
+
do_resize: bool = None,
|
| 272 |
+
size: Dict[str, int] = None,
|
| 273 |
+
resample: PILImageResampling = None,
|
| 274 |
+
do_center_crop: bool = None,
|
| 275 |
+
crop_size: Dict[str, int] = None,
|
| 276 |
+
do_reduce_labels: bool = None,
|
| 277 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 278 |
+
):
|
| 279 |
+
"""Preprocesses a single segmentation map."""
|
| 280 |
+
# All transformations expect numpy arrays.
|
| 281 |
+
segmentation_map = to_numpy_array(segmentation_map)
|
| 282 |
+
# Add an axis to the segmentation maps for transformations.
|
| 283 |
+
if segmentation_map.ndim == 2:
|
| 284 |
+
segmentation_map = segmentation_map[None, ...]
|
| 285 |
+
added_dimension = True
|
| 286 |
+
input_data_format = ChannelDimension.FIRST
|
| 287 |
+
else:
|
| 288 |
+
added_dimension = False
|
| 289 |
+
if input_data_format is None:
|
| 290 |
+
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
| 291 |
+
segmentation_map = self._preprocess(
|
| 292 |
+
image=segmentation_map,
|
| 293 |
+
do_reduce_labels=do_reduce_labels,
|
| 294 |
+
do_resize=do_resize,
|
| 295 |
+
resample=resample,
|
| 296 |
+
size=size,
|
| 297 |
+
do_center_crop=do_center_crop,
|
| 298 |
+
crop_size=crop_size,
|
| 299 |
+
do_normalize=False,
|
| 300 |
+
do_rescale=False,
|
| 301 |
+
input_data_format=ChannelDimension.FIRST,
|
| 302 |
+
)
|
| 303 |
+
# Remove extra axis if added
|
| 304 |
+
if added_dimension:
|
| 305 |
+
segmentation_map = np.squeeze(segmentation_map, axis=0)
|
| 306 |
+
segmentation_map = segmentation_map.astype(np.int64)
|
| 307 |
+
return segmentation_map
|
| 308 |
+
|
| 309 |
+
def __call__(self, images, segmentation_maps=None, **kwargs):
|
| 310 |
+
# Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
|
| 311 |
+
# be passed in as positional arguments.
|
| 312 |
+
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
|
| 313 |
+
|
| 314 |
+
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
|
| 315 |
+
@filter_out_non_signature_kwargs()
|
| 316 |
+
def preprocess(
|
| 317 |
+
self,
|
| 318 |
+
images: ImageInput,
|
| 319 |
+
segmentation_maps: Optional[ImageInput] = None,
|
| 320 |
+
do_resize: bool = None,
|
| 321 |
+
size: Dict[str, int] = None,
|
| 322 |
+
resample: PILImageResampling = None,
|
| 323 |
+
do_center_crop: bool = None,
|
| 324 |
+
crop_size: Dict[str, int] = None,
|
| 325 |
+
do_rescale: bool = None,
|
| 326 |
+
rescale_factor: float = None,
|
| 327 |
+
do_normalize: bool = None,
|
| 328 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 329 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 330 |
+
do_reduce_labels: Optional[bool] = None,
|
| 331 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 332 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
| 333 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 334 |
+
) -> PIL.Image.Image:
|
| 335 |
+
"""
|
| 336 |
+
Preprocess an image or batch of images.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
images (`ImageInput`):
|
| 340 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 341 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 342 |
+
segmentation_maps (`ImageInput`, *optional*):
|
| 343 |
+
Segmentation maps to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 344 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 345 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 346 |
+
Whether to resize the image.
|
| 347 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 348 |
+
Size of the image after resizing.
|
| 349 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
| 350 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 351 |
+
has an effect if `do_resize` is set to `True`.
|
| 352 |
+
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
| 353 |
+
Whether to center crop the image.
|
| 354 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
| 355 |
+
Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
|
| 356 |
+
padded with zeros and then cropped.
|
| 357 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 358 |
+
Whether to rescale the image values to the [0, 1] range.
|
| 359 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 360 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 361 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 362 |
+
Whether to normalize the image.
|
| 363 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 364 |
+
Image mean.
|
| 365 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 366 |
+
Image standard deviation.
|
| 367 |
+
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
|
| 368 |
+
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
|
| 369 |
+
is used for background, and background itself is not included in all classes of a dataset (e.g.
|
| 370 |
+
ADE20k). The background label will be replaced by 255.
|
| 371 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 372 |
+
The type of tensors to return. Can be one of:
|
| 373 |
+
- Unset: Return a list of `np.ndarray`.
|
| 374 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 375 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 376 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 377 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 378 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 379 |
+
The channel dimension format for the output image. Can be one of:
|
| 380 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 381 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 382 |
+
- Unset: Use the channel dimension format of the input image.
|
| 383 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 384 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 385 |
+
from the input image. Can be one of:
|
| 386 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 387 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 388 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 389 |
+
"""
|
| 390 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 391 |
+
size = size if size is not None else self.size
|
| 392 |
+
size = get_size_dict(size, default_to_square=True, param_name="size")
|
| 393 |
+
resample = resample if resample is not None else self.resample
|
| 394 |
+
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
| 395 |
+
crop_size = crop_size if crop_size is not None else self.crop_size
|
| 396 |
+
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
|
| 397 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 398 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 399 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 400 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 401 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 402 |
+
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
|
| 403 |
+
|
| 404 |
+
images = make_list_of_images(images)
|
| 405 |
+
|
| 406 |
+
if segmentation_maps is not None:
|
| 407 |
+
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
|
| 408 |
+
|
| 409 |
+
if segmentation_maps is not None and not valid_images(segmentation_maps):
|
| 410 |
+
raise ValueError(
|
| 411 |
+
"Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 412 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 413 |
+
)
|
| 414 |
+
if not valid_images(images):
|
| 415 |
+
raise ValueError(
|
| 416 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 417 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
validate_preprocess_arguments(
|
| 421 |
+
do_rescale=do_rescale,
|
| 422 |
+
rescale_factor=rescale_factor,
|
| 423 |
+
do_normalize=do_normalize,
|
| 424 |
+
image_mean=image_mean,
|
| 425 |
+
image_std=image_std,
|
| 426 |
+
do_center_crop=do_center_crop,
|
| 427 |
+
crop_size=crop_size,
|
| 428 |
+
do_resize=do_resize,
|
| 429 |
+
size=size,
|
| 430 |
+
resample=resample,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
images = [
|
| 434 |
+
self._preprocess_image(
|
| 435 |
+
image=img,
|
| 436 |
+
do_resize=do_resize,
|
| 437 |
+
do_center_crop=do_center_crop,
|
| 438 |
+
do_rescale=do_rescale,
|
| 439 |
+
do_normalize=do_normalize,
|
| 440 |
+
resample=resample,
|
| 441 |
+
size=size,
|
| 442 |
+
rescale_factor=rescale_factor,
|
| 443 |
+
crop_size=crop_size,
|
| 444 |
+
image_mean=image_mean,
|
| 445 |
+
image_std=image_std,
|
| 446 |
+
data_format=data_format,
|
| 447 |
+
input_data_format=input_data_format,
|
| 448 |
+
)
|
| 449 |
+
for img in images
|
| 450 |
+
]
|
| 451 |
+
|
| 452 |
+
data = {"pixel_values": images}
|
| 453 |
+
|
| 454 |
+
if segmentation_maps is not None:
|
| 455 |
+
segmentation_maps = [
|
| 456 |
+
self._preprocess_segmentation_map(
|
| 457 |
+
segmentation_map=segmentation_map,
|
| 458 |
+
do_reduce_labels=do_reduce_labels,
|
| 459 |
+
do_resize=do_resize,
|
| 460 |
+
resample=resample,
|
| 461 |
+
size=size,
|
| 462 |
+
do_center_crop=do_center_crop,
|
| 463 |
+
crop_size=crop_size,
|
| 464 |
+
)
|
| 465 |
+
for segmentation_map in segmentation_maps
|
| 466 |
+
]
|
| 467 |
+
data["labels"] = segmentation_maps
|
| 468 |
+
|
| 469 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 470 |
+
|
| 471 |
+
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
|
| 472 |
+
"""
|
| 473 |
+
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
outputs ([`BeitForSemanticSegmentation`]):
|
| 477 |
+
Raw outputs of the model.
|
| 478 |
+
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
|
| 479 |
+
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
|
| 480 |
+
predictions will not be resized.
|
| 481 |
+
|
| 482 |
+
Returns:
|
| 483 |
+
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
|
| 484 |
+
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
|
| 485 |
+
specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
|
| 486 |
+
"""
|
| 487 |
+
# TODO: add support for other frameworks
|
| 488 |
+
logits = outputs.logits
|
| 489 |
+
|
| 490 |
+
# Resize logits and compute semantic segmentation maps
|
| 491 |
+
if target_sizes is not None:
|
| 492 |
+
if len(logits) != len(target_sizes):
|
| 493 |
+
raise ValueError(
|
| 494 |
+
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
if is_torch_tensor(target_sizes):
|
| 498 |
+
target_sizes = target_sizes.numpy()
|
| 499 |
+
|
| 500 |
+
semantic_segmentation = []
|
| 501 |
+
|
| 502 |
+
for idx in range(len(logits)):
|
| 503 |
+
resized_logits = torch.nn.functional.interpolate(
|
| 504 |
+
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
|
| 505 |
+
)
|
| 506 |
+
semantic_map = resized_logits[0].argmax(dim=0)
|
| 507 |
+
semantic_segmentation.append(semantic_map)
|
| 508 |
+
else:
|
| 509 |
+
semantic_segmentation = logits.argmax(dim=1)
|
| 510 |
+
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
| 511 |
+
|
| 512 |
+
return semantic_segmentation
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
__all__ = ["BeitImageProcessor"]
|
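A minimal usage sketch for the `BeitImageProcessor` defined above, assuming `transformers`, `torch`, `Pillow` and `numpy` are installed; the random `image` and `seg_map` are placeholders for real data, and the post-processing call is shown for a hypothetical semantic-segmentation model output.

import numpy as np
from PIL import Image
from transformers import BeitImageProcessor

# Placeholder inputs -- swap in a real image and an ADE20k-style label map.
image = Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))
seg_map = np.random.randint(0, 150, (300, 400), dtype=np.uint8)

processor = BeitImageProcessor(do_reduce_labels=True)  # defaults: resize to 256, center-crop to 224
inputs = processor(images=image, segmentation_maps=seg_map, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
print(inputs["labels"].shape)        # torch.Size([1, 224, 224])

# Given `outputs` from a semantic-segmentation forward pass, per-pixel class ids at the
# original resolution can be recovered with:
# maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])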
.venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py
ADDED
|
@@ -0,0 +1,956 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Microsoft Research and the HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from typing import Callable, List, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import flax
|
| 20 |
+
import flax.linen as nn
|
| 21 |
+
import jax
|
| 22 |
+
import jax.numpy as jnp
|
| 23 |
+
import numpy as np
|
| 24 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
| 25 |
+
from flax.linen.attention import dot_product_attention_weights
|
| 26 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
| 27 |
+
|
| 28 |
+
from ...modeling_flax_outputs import (
|
| 29 |
+
FlaxBaseModelOutput,
|
| 30 |
+
FlaxBaseModelOutputWithPooling,
|
| 31 |
+
FlaxMaskedLMOutput,
|
| 32 |
+
FlaxSequenceClassifierOutput,
|
| 33 |
+
)
|
| 34 |
+
from ...modeling_flax_utils import (
|
| 35 |
+
ACT2FN,
|
| 36 |
+
FlaxPreTrainedModel,
|
| 37 |
+
append_replace_return_docstrings,
|
| 38 |
+
overwrite_call_docstring,
|
| 39 |
+
)
|
| 40 |
+
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
| 41 |
+
from .configuration_beit import BeitConfig
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@flax.struct.dataclass
|
| 45 |
+
class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
|
| 46 |
+
"""
|
| 47 |
+
Class for outputs of [`FlaxBeitModel`].
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 51 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 52 |
+
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
|
| 53 |
+
Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
|
| 54 |
+
*config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
|
| 55 |
+
will be returned.
|
| 56 |
+
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 57 |
+
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 58 |
+
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
|
| 59 |
+
the initial embedding outputs.
|
| 60 |
+
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 61 |
+
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 62 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 63 |
+
the self-attention heads.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
BEIT_START_DOCSTRING = r"""
|
| 68 |
+
|
| 69 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 70 |
+
library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
|
| 71 |
+
|
| 72 |
+
This model is also a
|
| 73 |
+
[flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
|
| 74 |
+
a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
|
| 75 |
+
behavior.
|
| 76 |
+
|
| 77 |
+
Finally, this model supports inherent JAX features such as:
|
| 78 |
+
|
| 79 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
| 80 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
| 81 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
| 82 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
| 83 |
+
|
| 84 |
+
Parameters:
|
| 85 |
+
config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
|
| 86 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 87 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 88 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
| 89 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
| 90 |
+
`jax.numpy.bfloat16` (on TPUs).
|
| 91 |
+
|
| 92 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
| 93 |
+
specified, all the computation will be performed with the given `dtype`.
|
| 94 |
+
|
| 95 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
| 96 |
+
parameters.**
|
| 97 |
+
|
| 98 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
| 99 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
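# Illustrative example for the `dtype` argument documented above (comment only, not part
# of the upstream module; the checkpoint name is an assumption):
#
#     model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224", dtype=jnp.bfloat16)
#
# This runs the forward computation in bfloat16 while the parameters stay in float32; cast
# the weights themselves with `model.params = model.to_bf16(model.params)` if desired.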
BEIT_INPUTS_DOCSTRING = r"""
|
| 103 |
+
Args:
|
| 104 |
+
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
|
| 105 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 106 |
+
[`AutoImageProcessor.__call__`] for details.
|
| 107 |
+
|
| 108 |
+
output_attentions (`bool`, *optional*):
|
| 109 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 110 |
+
tensors for more detail.
|
| 111 |
+
output_hidden_states (`bool`, *optional*):
|
| 112 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 113 |
+
more detail.
|
| 114 |
+
return_dict (`bool`, *optional*):
|
| 115 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def relative_position_index_init(window_size: Tuple[int, int]) -> jnp.ndarray:
|
| 120 |
+
"""
|
| 121 |
+
Get the pair-wise relative position index for each token inside the window.
|
| 122 |
+
"""
|
| 123 |
+
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 124 |
+
|
| 125 |
+
coords_h = np.arange(window_size[0])
|
| 126 |
+
coords_w = np.arange(window_size[1])
|
| 127 |
+
coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
|
| 128 |
+
coords_flatten = np.reshape(coords, (2, -1))
|
| 129 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 130 |
+
relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2
|
| 131 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
| 132 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
| 133 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
| 134 |
+
|
| 135 |
+
relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
| 136 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 137 |
+
relative_position_index[0, 0:] = num_relative_distance - 3
|
| 138 |
+
relative_position_index[0:, 0] = num_relative_distance - 2
|
| 139 |
+
relative_position_index[0, 0] = num_relative_distance - 1
|
| 140 |
+
return jnp.array(relative_position_index)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
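# Illustrative sketch of the helper above (not part of the upstream module): for a (2, 2)
# window the bias table needs (2*2-1)*(2*2-1) + 3 = 12 entries, the index matrix is
# (2*2+1) x (2*2+1), and the three extra slots encode the cls-to-token, token-to-cls and
# cls-to-cls interactions.
def _relative_position_index_example():  # hypothetical helper, added for exposition
    index = relative_position_index_init((2, 2))
    assert index.shape == (5, 5)
    assert int(index[0, 0]) == 12 - 1  # cls-to-cls uses the last slot of the bias table
    return index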
def ones_with_scale(key, shape, scale, dtype=jnp.float32):
|
| 144 |
+
return jnp.ones(shape, dtype) * scale
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class FlaxBeitDropPath(nn.Module):
|
| 148 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 149 |
+
|
| 150 |
+
rate: float
|
| 151 |
+
|
| 152 |
+
@nn.module.compact
|
| 153 |
+
def __call__(self, inputs, deterministic: Optional[bool] = True):
|
| 154 |
+
if self.rate == 0.0:
|
| 155 |
+
return inputs
|
| 156 |
+
keep_prob = 1.0 - self.rate
|
| 157 |
+
if deterministic:
|
| 158 |
+
return inputs
|
| 159 |
+
else:
|
| 160 |
+
shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 161 |
+
rng = self.make_rng("droppath")
|
| 162 |
+
random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
|
| 163 |
+
binary_tensor = jnp.floor(random_tensor)
|
| 164 |
+
output = inputs / keep_prob * binary_tensor
|
| 165 |
+
return output
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
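# Illustrative note on `FlaxBeitDropPath` above (comment only, not upstream code): with
# rate=0.1 and deterministic=False, each sample in the batch is dropped with probability
# 0.1 and the survivors are rescaled by 1 / 0.9, keeping the expected value of the
# residual branch unchanged. Callers must therefore provide a "droppath" PRNG key, e.g.
# `module.apply(variables, x, rngs={"droppath": key, "dropout": key2})`, as
# FlaxBeitPreTrainedModel does in `init_weights` and `__call__`.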
class FlaxBeitPatchEmbeddings(nn.Module):
|
| 169 |
+
config: BeitConfig
|
| 170 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 171 |
+
|
| 172 |
+
def setup(self):
|
| 173 |
+
self.num_channels = self.config.num_channels
|
| 174 |
+
image_size = self.config.image_size
|
| 175 |
+
patch_size = self.config.patch_size
|
| 176 |
+
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
| 177 |
+
patch_shape = (image_size // patch_size, image_size // patch_size)
|
| 178 |
+
self.num_patches = num_patches
|
| 179 |
+
self.patch_shape = patch_shape
|
| 180 |
+
self.projection = nn.Conv(
|
| 181 |
+
self.config.hidden_size,
|
| 182 |
+
kernel_size=(patch_size, patch_size),
|
| 183 |
+
strides=(patch_size, patch_size),
|
| 184 |
+
padding="VALID",
|
| 185 |
+
dtype=self.dtype,
|
| 186 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def __call__(self, pixel_values):
|
| 190 |
+
num_channels = pixel_values.shape[-1]
|
| 191 |
+
if num_channels != self.num_channels:
|
| 192 |
+
raise ValueError(
|
| 193 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 194 |
+
)
|
| 195 |
+
embeddings = self.projection(pixel_values)
|
| 196 |
+
batch_size, _, _, channels = embeddings.shape
|
| 197 |
+
return jnp.reshape(embeddings, (batch_size, -1, channels))
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class FlaxBeitEmbeddings(nn.Module):
|
| 201 |
+
"""Construct the CLS token, position and patch embeddings."""
|
| 202 |
+
|
| 203 |
+
config: BeitConfig
|
| 204 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 205 |
+
|
| 206 |
+
def setup(self):
|
| 207 |
+
self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
|
| 208 |
+
if self.config.use_mask_token:
|
| 209 |
+
self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
|
| 210 |
+
self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype)
|
| 211 |
+
num_patches = self.patch_embeddings.num_patches
|
| 212 |
+
if self.config.use_absolute_position_embeddings:
|
| 213 |
+
self.position_embeddings = self.param(
|
| 214 |
+
"position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
|
| 215 |
+
)
|
| 216 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
| 217 |
+
|
| 218 |
+
def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True):
|
| 219 |
+
embeddings = self.patch_embeddings(pixel_values)
|
| 220 |
+
batch_size, seq_len, _ = embeddings.shape
|
| 221 |
+
|
| 222 |
+
cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
|
| 223 |
+
cls_tokens = cls_tokens.astype(embeddings.dtype)
|
| 224 |
+
|
| 225 |
+
if bool_masked_pos is not None:
|
| 226 |
+
mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size))
|
| 227 |
+
mask_tokens = mask_tokens.astype(embeddings.dtype)
|
| 228 |
+
# replace the masked visual tokens by mask_tokens
|
| 229 |
+
w = jnp.expand_dims(bool_masked_pos, axis=-1)
|
| 230 |
+
embeddings = embeddings * (1 - w) + mask_tokens * w
|
| 231 |
+
|
| 232 |
+
embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
|
| 233 |
+
|
| 234 |
+
if self.config.use_absolute_position_embeddings:
|
| 235 |
+
embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype)
|
| 236 |
+
|
| 237 |
+
embeddings = self.dropout(embeddings, deterministic=deterministic)
|
| 238 |
+
return embeddings
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class FlaxBeitRelativePositionBias(nn.Module):
|
| 242 |
+
config: BeitConfig
|
| 243 |
+
window_size: Tuple[int, int]
|
| 244 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 245 |
+
|
| 246 |
+
def setup(self):
|
| 247 |
+
num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3
|
| 248 |
+
self.relative_position_bias_table = self.param(
|
| 249 |
+
"relative_position_bias_table",
|
| 250 |
+
nn.initializers.zeros,
|
| 251 |
+
(num_relative_distance, self.config.num_attention_heads),
|
| 252 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
| 253 |
+
# cls to token & token to cls & cls to cls
|
| 254 |
+
|
| 255 |
+
self.relative_position_index = relative_position_index_init(self.window_size)
|
| 256 |
+
|
| 257 |
+
def __call__(self):
|
| 258 |
+
index = self.relative_position_index.reshape(-1)
|
| 259 |
+
shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1)
|
| 260 |
+
relative_position_bias = self.relative_position_bias_table[index].reshape(shape) # Wh*Ww,Wh*Ww,nH
|
| 261 |
+
return jnp.transpose(relative_position_bias, (2, 0, 1))
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class FlaxBeitSelfAttention(nn.Module):
|
| 265 |
+
config: BeitConfig
|
| 266 |
+
window_size: Tuple[int, int]
|
| 267 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 268 |
+
|
| 269 |
+
def setup(self):
|
| 270 |
+
if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr(
|
| 271 |
+
self.config, "embedding_size"
|
| 272 |
+
):
|
| 273 |
+
raise ValueError(
|
| 274 |
+
f"The hidden size {self.config.hidden_size,} is not a multiple of the number of attention "
|
| 275 |
+
f"heads {self.config.num_attention_heads}."
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
self.query = nn.Dense(
|
| 279 |
+
self.config.hidden_size,
|
| 280 |
+
dtype=self.dtype,
|
| 281 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 282 |
+
)
|
| 283 |
+
self.key = nn.Dense(
|
| 284 |
+
self.config.hidden_size,
|
| 285 |
+
dtype=self.dtype,
|
| 286 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 287 |
+
use_bias=False,
|
| 288 |
+
)
|
| 289 |
+
self.value = nn.Dense(
|
| 290 |
+
self.config.hidden_size,
|
| 291 |
+
dtype=self.dtype,
|
| 292 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
self.relative_position_bias = (
|
| 296 |
+
FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype)
|
| 297 |
+
if self.window_size
|
| 298 |
+
else None
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def __call__(
|
| 302 |
+
self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
|
| 303 |
+
):
|
| 304 |
+
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
| 305 |
+
|
| 306 |
+
query_states = self.query(hidden_states).reshape(
|
| 307 |
+
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
| 308 |
+
)
|
| 309 |
+
value_states = self.value(hidden_states).reshape(
|
| 310 |
+
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
| 311 |
+
)
|
| 312 |
+
key_states = self.key(hidden_states).reshape(
|
| 313 |
+
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
dropout_rng = None
|
| 317 |
+
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
|
| 318 |
+
dropout_rng = self.make_rng("dropout")
|
| 319 |
+
|
| 320 |
+
attention_bias = jnp.array(0.0, dtype=self.dtype)
|
| 321 |
+
# Add relative position bias if present.
|
| 322 |
+
if self.relative_position_bias is not None:
|
| 323 |
+
attention_bias = jnp.expand_dims(self.relative_position_bias(), 0)
|
| 324 |
+
attention_bias = attention_bias.astype(query_states.dtype)
|
| 325 |
+
|
| 326 |
+
# Add shared relative position bias if provided.
|
| 327 |
+
if relative_position_bias is not None:
|
| 328 |
+
attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype)
|
| 329 |
+
|
| 330 |
+
attn_weights = dot_product_attention_weights(
|
| 331 |
+
query_states,
|
| 332 |
+
key_states,
|
| 333 |
+
bias=attention_bias,
|
| 334 |
+
dropout_rng=dropout_rng,
|
| 335 |
+
dropout_rate=self.config.attention_probs_dropout_prob,
|
| 336 |
+
broadcast_dropout=True,
|
| 337 |
+
deterministic=deterministic,
|
| 338 |
+
dtype=self.dtype,
|
| 339 |
+
precision=None,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
| 343 |
+
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
| 344 |
+
|
| 345 |
+
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
| 346 |
+
return outputs
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class FlaxBeitSelfOutput(nn.Module):
|
| 350 |
+
config: BeitConfig
|
| 351 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 352 |
+
|
| 353 |
+
def setup(self):
|
| 354 |
+
self.dense = nn.Dense(
|
| 355 |
+
self.config.hidden_size,
|
| 356 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 357 |
+
dtype=self.dtype,
|
| 358 |
+
)
|
| 359 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
| 360 |
+
|
| 361 |
+
def __call__(self, hidden_states, deterministic: bool = True):
|
| 362 |
+
hidden_states = self.dense(hidden_states)
|
| 363 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
| 364 |
+
return hidden_states
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class FlaxBeitAttention(nn.Module):
|
| 368 |
+
config: BeitConfig
|
| 369 |
+
window_size: Tuple[int, int]
|
| 370 |
+
dtype: jnp.dtype = jnp.float32
|
| 371 |
+
|
| 372 |
+
def setup(self):
|
| 373 |
+
self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype)
|
| 374 |
+
self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype)
|
| 375 |
+
|
| 376 |
+
def __call__(
|
| 377 |
+
self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False
|
| 378 |
+
):
|
| 379 |
+
attn_outputs = self.attention(
|
| 380 |
+
hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
|
| 381 |
+
)
|
| 382 |
+
attn_output = attn_outputs[0]
|
| 383 |
+
attn_output = self.output(attn_output, deterministic=deterministic)
|
| 384 |
+
|
| 385 |
+
outputs = (attn_output,)
|
| 386 |
+
|
| 387 |
+
if output_attentions:
|
| 388 |
+
outputs += (attn_outputs[1],)
|
| 389 |
+
|
| 390 |
+
return outputs
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class FlaxBeitIntermediate(nn.Module):
|
| 394 |
+
config: BeitConfig
|
| 395 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 396 |
+
|
| 397 |
+
def setup(self):
|
| 398 |
+
self.dense = nn.Dense(
|
| 399 |
+
self.config.intermediate_size,
|
| 400 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 401 |
+
dtype=self.dtype,
|
| 402 |
+
)
|
| 403 |
+
self.activation = ACT2FN[self.config.hidden_act]
|
| 404 |
+
|
| 405 |
+
def __call__(self, hidden_states):
|
| 406 |
+
hidden_states = self.dense(hidden_states)
|
| 407 |
+
hidden_states = self.activation(hidden_states)
|
| 408 |
+
|
| 409 |
+
return hidden_states
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class FlaxBeitOutput(nn.Module):
|
| 413 |
+
config: BeitConfig
|
| 414 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 415 |
+
|
| 416 |
+
def setup(self):
|
| 417 |
+
self.dense = nn.Dense(
|
| 418 |
+
self.config.hidden_size,
|
| 419 |
+
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
| 420 |
+
dtype=self.dtype,
|
| 421 |
+
)
|
| 422 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
| 423 |
+
|
| 424 |
+
def __call__(self, hidden_states, deterministic: bool = True):
|
| 425 |
+
hidden_states = self.dense(hidden_states)
|
| 426 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
| 427 |
+
|
| 428 |
+
return hidden_states
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class FlaxBeitLayer(nn.Module):
|
| 432 |
+
config: BeitConfig
|
| 433 |
+
window_size: Tuple[int, int]
|
| 434 |
+
drop_path_rate: float
|
| 435 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 436 |
+
|
| 437 |
+
def setup(self):
|
| 438 |
+
self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype)
|
| 439 |
+
self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
|
| 440 |
+
self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
|
| 441 |
+
self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
| 442 |
+
self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
|
| 443 |
+
self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
| 444 |
+
|
| 445 |
+
self.init_values = self.config.layer_scale_init_value
|
| 446 |
+
if self.init_values > 0:
|
| 447 |
+
self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values)
|
| 448 |
+
self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values)
|
| 449 |
+
else:
|
| 450 |
+
self.lambda_1 = None
|
| 451 |
+
self.lambda_2 = None
|
| 452 |
+
|
| 453 |
+
def __call__(
|
| 454 |
+
self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
|
| 455 |
+
):
|
| 456 |
+
self_attention_outputs = self.attention(
|
| 457 |
+
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
| 458 |
+
relative_position_bias,
|
| 459 |
+
deterministic=deterministic,
|
| 460 |
+
output_attentions=output_attentions,
|
| 461 |
+
)
|
| 462 |
+
attention_output = self_attention_outputs[0]
|
| 463 |
+
|
| 464 |
+
# apply lambda_1 if present
|
| 465 |
+
if self.lambda_1 is not None:
|
| 466 |
+
attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output
|
| 467 |
+
|
| 468 |
+
# first residual connection
|
| 469 |
+
hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states
|
| 470 |
+
|
| 471 |
+
# in BEiT, layernorm is also applied after self-attention
|
| 472 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 473 |
+
|
| 474 |
+
layer_output = self.intermediate(layer_output)
|
| 475 |
+
layer_output = self.output(layer_output, deterministic=deterministic)
|
| 476 |
+
|
| 477 |
+
# apply lambda_2 if present
|
| 478 |
+
if self.lambda_2 is not None:
|
| 479 |
+
layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output
|
| 480 |
+
|
| 481 |
+
# second residual connection
|
| 482 |
+
layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states
|
| 483 |
+
|
| 484 |
+
outputs = (layer_output,)
|
| 485 |
+
|
| 486 |
+
if output_attentions:
|
| 487 |
+
outputs += (self_attention_outputs[1],)
|
| 488 |
+
|
| 489 |
+
return outputs
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class FlaxBeitLayerCollection(nn.Module):
|
| 493 |
+
config: BeitConfig
|
| 494 |
+
window_size: Tuple[int, int]
|
| 495 |
+
drop_path_rates: List[float]
|
| 496 |
+
relative_position_bias: Callable[[], jnp.ndarray]
|
| 497 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 498 |
+
|
| 499 |
+
def setup(self):
|
| 500 |
+
self.layers = [
|
| 501 |
+
FlaxBeitLayer(
|
| 502 |
+
self.config,
|
| 503 |
+
window_size=self.window_size if self.config.use_relative_position_bias else None,
|
| 504 |
+
drop_path_rate=self.drop_path_rates[i],
|
| 505 |
+
name=str(i),
|
| 506 |
+
dtype=self.dtype,
|
| 507 |
+
)
|
| 508 |
+
for i in range(self.config.num_hidden_layers)
|
| 509 |
+
]
|
| 510 |
+
|
| 511 |
+
def __call__(
|
| 512 |
+
self,
|
| 513 |
+
hidden_states,
|
| 514 |
+
deterministic: bool = True,
|
| 515 |
+
output_attentions: bool = False,
|
| 516 |
+
output_hidden_states: bool = False,
|
| 517 |
+
return_dict: bool = True,
|
| 518 |
+
):
|
| 519 |
+
all_attentions = () if output_attentions else None
|
| 520 |
+
all_hidden_states = () if output_hidden_states else None
|
| 521 |
+
|
| 522 |
+
for i, layer in enumerate(self.layers):
|
| 523 |
+
if output_hidden_states:
|
| 524 |
+
all_hidden_states += (hidden_states,)
|
| 525 |
+
relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None
|
| 526 |
+
layer_outputs = layer(
|
| 527 |
+
hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
hidden_states = layer_outputs[0]
|
| 531 |
+
|
| 532 |
+
if output_attentions:
|
| 533 |
+
all_attentions += (layer_outputs[1],)
|
| 534 |
+
|
| 535 |
+
if output_hidden_states:
|
| 536 |
+
all_hidden_states += (hidden_states,)
|
| 537 |
+
|
| 538 |
+
outputs = (hidden_states,)
|
| 539 |
+
if not return_dict:
|
| 540 |
+
return tuple(v for v in outputs if v is not None)
|
| 541 |
+
|
| 542 |
+
return FlaxBaseModelOutput(
|
| 543 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
class FlaxBeitEncoder(nn.Module):
|
| 548 |
+
config: BeitConfig
|
| 549 |
+
window_size: Tuple[int, int]
|
| 550 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 551 |
+
|
| 552 |
+
def setup(self):
|
| 553 |
+
if self.config.use_shared_relative_position_bias:
|
| 554 |
+
self.relative_position_bias = FlaxBeitRelativePositionBias(
|
| 555 |
+
                config=self.config, window_size=self.window_size, dtype=self.dtype
            )

        # stochastic depth decay rule
        drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
        self.layer = FlaxBeitLayerCollection(
            self.config,
            window_size=self.window_size,
            drop_path_rates=drop_path_rates,
            relative_position_bias=self.relative_position_bias
            if self.config.use_shared_relative_position_bias
            else None,
            dtype=self.dtype,
        )

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BeitConfig
    base_model_prefix = "beit"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BeitConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        pixel_values,
        bool_masked_pos=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            dropout_rng, droppath_rng = jax.random.split(dropout_rng)
            rngs["dropout"] = dropout_rng
            rngs["droppath"] = droppath_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            bool_masked_pos,
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )


class FlaxBeitPooler(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        if self.config.use_mean_pooling:
            # Mean pool the final hidden states of the patch tokens
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))
        else:
            # Pool by simply taking the final hidden state of the [CLS] token
            pooled_output = hidden_states[:, 0]

        return pooled_output


class FlaxBeitModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBeitEncoder(
            self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype
        )
        if not self.config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None

    def __call__(
        self,
        pixel_values,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        if not self.config.use_mean_pooling:
            hidden_states = self.layernorm(hidden_states)
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBeitModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
    BEIT_START_DOCSTRING,
)
class FlaxBeitModel(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitModule


FLAX_BEIT_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
    >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)


class FlaxBeitForMaskedImageModelingModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype)

        # Classifier head
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        pixel_values=None,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.beit(
            pixel_values,
            bool_masked_pos,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.layernorm(sequence_output)
        prediction_scores = self.lm_head(sequence_output[:, 1:])

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return output

        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
    BEIT_START_DOCSTRING,
)
class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitForMaskedImageModelingModule


FLAX_BEIT_MLM_DOCSTRING = """
    bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`):
        Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
    >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
"""

overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING)
append_replace_return_docstrings(
    FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig
)


class FlaxBeitForImageClassificationModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True)
        self.classifier = nn.Dense(
            self.config.num_labels,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        pixel_values=None,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.beit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    """,
    BEIT_START_DOCSTRING,
)
class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitForImageClassificationModule


FLAX_BEIT_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
    >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = logits.argmax(-1).item()
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
    ```
"""

overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
    FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig
)


__all__ = [
    "FlaxBeitForImageClassification",
    "FlaxBeitForMaskedImageModeling",
    "FlaxBeitModel",
    "FlaxBeitPreTrainedModel",
]
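Note on the stochastic depth decay rule used in `FlaxBeitEncoder.setup` above: the per-layer drop-path probabilities are just a linear ramp from 0 to `config.drop_path_rate`. A minimal sketch of that schedule in plain NumPy, with hypothetical values standing in for the `BeitConfig` fields:

```python
import numpy as np

# Hypothetical stand-ins for BeitConfig.drop_path_rate / BeitConfig.num_hidden_layers
drop_path_rate = 0.1
num_hidden_layers = 12

# Same rule as in FlaxBeitEncoder.setup: shallow layers keep everything,
# deeper layers are dropped with increasing probability during training.
drop_path_rates = list(np.linspace(0, drop_path_rate, num_hidden_layers))
print([round(r, 3) for r in drop_path_rates])  # ramps from 0.0 up to 0.1
```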
.venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .tokenization_code_llama import *
    from .tokenization_code_llama_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
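For context, the `else` branch above swaps the package module for a `_LazyModule`, so the tokenizer files are only imported when one of their attributes is first accessed. A small usage sketch; the checkpoint name is an illustrative assumption:

```python
# Accessing the class triggers the lazy import of tokenization_code_llama_fast.py;
# nothing heavier than the package __init__ runs before this attribute lookup.
from transformers.models.code_llama import CodeLlamaTokenizerFast

tokenizer = CodeLlamaTokenizerFast.from_pretrained("codellama/CodeLlama-7b-hf")  # assumed checkpoint
print(tokenizer("def add(a, b):")["input_ids"])
```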
.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (786 Bytes).

.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc
ADDED
Binary file (22.5 kB).

.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc
ADDED
Binary file (18 kB).
.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py
ADDED
@@ -0,0 +1,452 @@
# coding=utf-8
# Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization classes for Code LLaMA."""

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm

from ...convert_slow_tokenizer import import_protobuf
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging, requires_backends


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

SPIECE_UNDERLINE = "▁"

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on


class CodeLlamaTokenizer(PreTrainedTokenizer):
    """
    Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as
    there is no padding token in the original model.

    The default configuration match that of
    [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
    which supports prompt infilling.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
            Prefix token used for infilling.
        middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
            Middle token used for infilling.
        suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
            Suffix token used for infilling.
        eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
            End of text token used for infilling.
        fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
            The token used to split the input between the prefix and suffix.
        suffix_first (`bool`, *optional*, defaults to `False`):
            Whether the input prompt and suffix should be formatted with the suffix first.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether to add a beginning of sequence token at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether to add an end of sequence token at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up the tokenization spaces.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Llama should be used.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        prefix_token="▁<PRE>",
        middle_token="▁<MID>",
        suffix_token="▁<SUF>",
        eot_token="▁<EOT>",
        fill_token="<FILL_ME>",
        suffix_first=False,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        additional_special_tokens=None,
        use_default_system_prompt=False,
        **kwargs,
    ):
        requires_backends(self, "protobuf")
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token

        self.use_default_system_prompt = use_default_system_prompt
        # mark tokens special to skip them
        additional_special_tokens = additional_special_tokens or []
        for token in [prefix_token, middle_token, suffix_token, eot_token]:
            additional_special_tokens += [token] if token is not None else []

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self._prefix_token = prefix_token
        self._middle_token = middle_token
        self._suffix_token = suffix_token
        self._eot_token = eot_token
        self.fill_token = fill_token
        self.suffix_first = suffix_first
        self.sp_model = self.get_spm_processor()

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            prefix_token=prefix_token,
            middle_token=middle_token,
            suffix_token=suffix_token,
            eot_token=eot_token,
            fill_token=fill_token,
            sp_model_kwargs=self.sp_model_kwargs,
            suffix_first=suffix_first,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            additional_special_tokens=additional_special_tokens,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    @property
    def unk_token_length(self):
        return len(self.sp_model.encode(str(self.unk_token)))

    def get_spm_processor(self):
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
            model_pb2 = import_protobuf()
            model = model_pb2.ModelProto.FromString(sp_model)
            normalizer_spec = model_pb2.NormalizerSpec()
            normalizer_spec.add_dummy_prefix = False
            model.normalizer_spec.MergeFrom(normalizer_spec)
            sp_model = model.SerializeToString()
            tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    def prefix_token(self):
        return self._prefix_token

    @property
    def prefix_id(self):
        if self._prefix_token is None:
            return None
        return self.convert_tokens_to_ids(self.prefix_token)

    @property
    def middle_token(self):
        return self._middle_token

    @property
    def middle_id(self):
        if self._middle_token is None:
            return None
        return self.convert_tokens_to_ids(self.middle_token)

    @property
    def suffix_token(self):
        return self._suffix_token

    @property
    def suffix_id(self):
        if self._suffix_token is None:
            return None
        return self.convert_tokens_to_ids(self.suffix_token)

    @property
    def eot_token(self):
        return self._eot_token

    @property
    def eot_id(self):
        if self._eot_token is None:
            return None
        return self.convert_tokens_to_ids(self.eot_token)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[int]:
        # add a prefix space to `prefix`
        if self.fill_token is not None and self.fill_token in prefix and suffix is None:
            prefix, suffix = prefix.split(self.fill_token)

        if len(prefix) > 0:
            prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")

        if suffix is None or len(suffix) < 1:
            tokens = super().tokenize(prefix, **kwargs)
            if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
                tokens = tokens[1:]
            return tokens

        prefix_tokens = self._tokenize(prefix)  # prefix has an extra `SPIECE_UNDERLINE`

        if None in (self.prefix_id, self.middle_id, self.suffix_id):
            raise ValueError(
                "The input either includes a `prefix` and a `suffix` used for the infilling task,"
                f" or can be split on the {self.fill_token} token, creating a suffix and prefix,"
                " but the model does not support `infilling`."
            )
        suffix_tokens = self._tokenize(suffix)  # make sure CodeLlama sp model does not mess up

        suffix_first = suffix_first if suffix_first is not None else self.suffix_first
        if suffix_first:
            # format as " <PRE> <SUF>{suf} <MID> {pre}"
            return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
        else:
            # format as " <PRE> {pre} <SUF>{suf} <MID>"
            return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]

    def _tokenize(self, text, **kwargs):
        """
        Returns a tokenized string.

        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
        """
        tokens = self.sp_model.encode(text, out_type=str)
        if not text.startswith((SPIECE_UNDERLINE, " ")):
            return tokens
        # 1. Encode string + prefix ex: "<unk> Hey"
        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
        return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # since we manually add the prefix space, we have to remove it when decoding
        if tokens[0].startswith(SPIECE_UNDERLINE):
            tokens[0] = tokens[0][1:]

        current_sub_tokens = []
        out_string = ""
        for _, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )

    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        if token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)

        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

        return output

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)


__all__ = ["CodeLlamaTokenizer"]
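A short sketch of the infilling path implemented by `CodeLlamaTokenizer.tokenize` above: when the prompt contains `fill_token` (`"<FILL_ME>"`) and no explicit suffix, the prompt is split into a prefix and a suffix and re-assembled as `" <PRE> {pre} <SUF>{suf} <MID>"` (suffix-first when requested). The checkpoint name below is an assumption for illustration, and the snippet requires `sentencepiece` and `protobuf`:

```python
from transformers import CodeLlamaTokenizer

tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")  # assumed checkpoint
prompt = "def remove_non_ascii(s: str) -> str:\n    <FILL_ME>\n    return result"

# tokenize() splits on <FILL_ME> and frames the two pieces with ▁<PRE> / ▁<SUF> / ▁<MID>
tokens = tokenizer.tokenize(prompt)
print(tokens[0], "...", tokens[-1])  # expect the prefix token first and the middle token last
```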
.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import os
|
| 16 |
+
from shutil import copyfile
|
| 17 |
+
from typing import List, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
from tokenizers import normalizers, processors
|
| 20 |
+
|
| 21 |
+
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
| 22 |
+
from ...utils import is_sentencepiece_available, logging
|
| 23 |
+
from ...utils.versions import require_version
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
require_version("tokenizers>=0.13.3")
|
| 27 |
+
|
| 28 |
+
if is_sentencepiece_available():
|
| 29 |
+
from .tokenization_code_llama import CodeLlamaTokenizer
|
| 30 |
+
else:
|
| 31 |
+
CodeLlamaTokenizer = None
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__)
|
| 34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
|
| 35 |
+
|
| 36 |
+
SPIECE_UNDERLINE = "▁"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
| 40 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
| 41 |
+
|
| 42 |
+
# fmt: off
|
| 43 |
+
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
|
| 44 |
+
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
| 45 |
+
that your responses are socially unbiased and positive in nature.
|
| 46 |
+
|
| 47 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
| 48 |
+
correct. If you don't know the answer to a question, please don't share false information."""
|
| 49 |
+
# fmt: on
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
|
| 53 |
+
"""
|
| 54 |
+
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 55 |
+
|
| 56 |
+
This uses notably ByteFallback and no normalization.
|
| 57 |
+
|
| 58 |
+
```python
|
| 59 |
+
>>> from transformers import CodeLlamaTokenizerFast
|
| 60 |
+
|
| 61 |
+
>>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
|
| 62 |
+
>>> tokenizer.encode("Hello this is a test")
|
| 63 |
+
[1, 15043, 445, 338, 263, 1243]
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
|
| 67 |
+
call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
|
| 68 |
+
values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
|
| 69 |
+
[post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
| 73 |
+
refer to this superclass for more information regarding those methods. The default configuration match that of
|
| 74 |
+
[meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
|
| 75 |
+
which supports prompt infilling.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
vocab_file (`str`, *optional*):
|
| 79 |
+
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
|
| 80 |
+
contains the vocabulary necessary to instantiate a tokenizer.
|
| 81 |
+
tokenizer_file (`str`, *optional*):
|
| 82 |
+
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
|
| 83 |
+
contains everything needed to load the tokenizer.
|
| 84 |
+
clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
|
| 85 |
+
Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
|
| 86 |
+
spaces.
|
| 87 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 88 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 89 |
+
token instead.
|
| 90 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 91 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
| 92 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 93 |
+
The end of sequence token.
|
| 94 |
+
prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
|
| 95 |
+
Prefix token used for infilling.
|
| 96 |
+
middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
|
| 97 |
+
Middle token used for infilling.
|
| 98 |
+
suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
|
| 99 |
+
Suffix token used for infilling.
|
| 100 |
+
eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
|
| 101 |
+
End of text token used for infilling.
|
| 102 |
+
fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
|
| 103 |
+
The token used to split the input between the prefix and suffix.
|
| 104 |
+
additional_special_tokens (`List[str]`, *optional*):
|
| 105 |
+
Additional special tokens used by the tokenizer.
|
| 106 |
+
add_bos_token (`bool`, *optional*, defaults to `True`):
|
| 107 |
+
Whether to add a beginning of sequence token at the start of sequences.
|
| 108 |
+
add_eos_token (`bool`, *optional*, defaults to `False`):
|
| 109 |
+
Whether to add an end of sequence token at the end of sequences.
|
| 110 |
+
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
|
| 111 |
+
Whether or not the default system prompt for Llama should be used.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 115 |
+
slow_tokenizer_class = CodeLlamaTokenizer
|
| 116 |
+
padding_side = "left"
|
| 117 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 118 |
+
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
vocab_file=None,
|
| 122 |
+
tokenizer_file=None,
|
| 123 |
+
clean_up_tokenization_spaces=False,
|
| 124 |
+
unk_token="<unk>",
|
| 125 |
+
bos_token="<s>",
|
| 126 |
+
eos_token="</s>",
|
| 127 |
+
prefix_token="▁<PRE>",
|
| 128 |
+
middle_token="▁<MID>",
|
| 129 |
+
suffix_token="▁<SUF>",
|
| 130 |
+
eot_token="▁<EOT>",
|
| 131 |
+
fill_token="<FILL_ME>",
|
| 132 |
+
additional_special_tokens=None,
|
| 133 |
+
add_bos_token=True,
|
| 134 |
+
add_eos_token=False,
|
| 135 |
+
use_default_system_prompt=False,
|
| 136 |
+
**kwargs,
|
| 137 |
+
):
|
| 138 |
+
# mark tokens special to skip them
|
| 139 |
+
additional_special_tokens = additional_special_tokens or []
|
| 140 |
+
for token in [prefix_token, middle_token, suffix_token, eot_token]:
|
| 141 |
+
additional_special_tokens += [token] if token is not None else []
|
| 142 |
+
self.use_default_system_prompt = use_default_system_prompt
|
| 143 |
+
|
| 144 |
+
super().__init__(
|
| 145 |
+
vocab_file=vocab_file,
|
| 146 |
+
tokenizer_file=tokenizer_file,
|
| 147 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 148 |
+
additional_special_tokens=additional_special_tokens,
|
| 149 |
+
unk_token=unk_token,
|
| 150 |
+
bos_token=bos_token,
|
| 151 |
+
eos_token=eos_token,
|
| 152 |
+
add_bos_token=add_bos_token,
|
| 153 |
+
add_eos_token=add_eos_token,
|
| 154 |
+
prefix_token=prefix_token,
|
| 155 |
+
middle_token=middle_token,
|
| 156 |
+
suffix_token=suffix_token,
|
| 157 |
+
eot_token=eot_token,
|
| 158 |
+
fill_token=fill_token,
|
| 159 |
+
use_default_system_prompt=use_default_system_prompt,
|
| 160 |
+
**kwargs,
|
| 161 |
+
)
|
| 162 |
+
self._add_bos_token = add_bos_token
|
| 163 |
+
self._add_eos_token = add_eos_token
|
| 164 |
+
self.update_post_processor()
|
| 165 |
+
|
| 166 |
+
self.vocab_file = vocab_file
|
| 167 |
+
|
| 168 |
+
self._prefix_token = prefix_token
|
| 169 |
+
self._middle_token = middle_token
|
| 170 |
+
self._suffix_token = suffix_token
|
| 171 |
+
self._eot_token = eot_token
|
| 172 |
+
self.fill_token = fill_token
|
| 173 |
+
|
| 174 |
+
@property
|
| 175 |
+
def can_save_slow_tokenizer(self) -> bool:
|
| 176 |
+
return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
| 177 |
+
|
| 178 |
+
# Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
|
| 179 |
+
def update_post_processor(self):
|
| 180 |
+
"""
|
| 181 |
+
Updates the underlying post processor with the current `bos_token` and `eos_token`.
|
| 182 |
+
"""
|
| 183 |
+
bos = self.bos_token
|
| 184 |
+
bos_token_id = self.bos_token_id
|
| 185 |
+
if bos is None and self.add_bos_token:
|
| 186 |
+
raise ValueError("add_bos_token = True but bos_token = None")
|
| 187 |
+
|
| 188 |
+
eos = self.eos_token
|
| 189 |
+
eos_token_id = self.eos_token_id
|
| 190 |
+
if eos is None and self.add_eos_token:
|
| 191 |
+
raise ValueError("add_eos_token = True but eos_token = None")
|
| 192 |
+
|
| 193 |
+
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
| 194 |
+
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
| 195 |
+
|
| 196 |
+
special_tokens = []
|
| 197 |
+
if self.add_bos_token:
|
| 198 |
+
special_tokens.append((bos, bos_token_id))
|
| 199 |
+
if self.add_eos_token:
|
| 200 |
+
special_tokens.append((eos, eos_token_id))
|
| 201 |
+
self._tokenizer.post_processor = processors.TemplateProcessing(
|
| 202 |
+
single=single, pair=pair, special_tokens=special_tokens
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
@property
|
| 206 |
+
def prefix_token(self):
|
| 207 |
+
return self._prefix_token
|
| 208 |
+
|
| 209 |
+
@property
|
| 210 |
+
def prefix_id(self):
|
| 211 |
+
if self._prefix_token is None:
|
| 212 |
+
return None
|
| 213 |
+
return self.convert_tokens_to_ids(self.prefix_token)
|
| 214 |
+
|
| 215 |
+
@property
|
| 216 |
+
def middle_token(self):
|
| 217 |
+
return self._middle_token
|
| 218 |
+
|
| 219 |
+
@property
|
| 220 |
+
def middle_id(self):
|
| 221 |
+
if self._middle_token is None:
|
| 222 |
+
return None
|
| 223 |
+
return self.convert_tokens_to_ids(self.middle_token)
|
| 224 |
+
|
| 225 |
+
@property
|
| 226 |
+
def suffix_token(self):
|
| 227 |
+
return self._suffix_token
|
| 228 |
+
|
| 229 |
+
@property
|
| 230 |
+
def suffix_id(self):
|
| 231 |
+
if self._suffix_token is None:
|
| 232 |
+
return None
|
| 233 |
+
return self.convert_tokens_to_ids(self.suffix_token)
|
| 234 |
+
|
| 235 |
+
@property
|
| 236 |
+
def eot_id(self):
|
| 237 |
+
if self._eot_token is None:
|
| 238 |
+
return None
|
| 239 |
+
return self.convert_tokens_to_ids(self.eot_token)
|
| 240 |
+
|
| 241 |
+
@property
|
| 242 |
+
def eot_token(self):
|
| 243 |
+
return self._eot_token
|
| 244 |
+
|
| 245 |
+
@property
|
| 246 |
+
def add_eos_token(self):
|
| 247 |
+
return self._add_eos_token
|
| 248 |
+
|
| 249 |
+
@property
|
| 250 |
+
def add_bos_token(self):
|
| 251 |
+
return self._add_bos_token
|
| 252 |
+
|
| 253 |
+
@add_eos_token.setter
|
| 254 |
+
def add_eos_token(self, value):
|
| 255 |
+
self._add_eos_token = value
|
| 256 |
+
self.update_post_processor()
|
| 257 |
+
|
| 258 |
+
@add_bos_token.setter
|
| 259 |
+
def add_bos_token(self, value):
|
| 260 |
+
self._add_bos_token = value
|
+        self.update_post_processor()
+
+    def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
+        """
+        Updates the normalizer to make sure the prompt format for `infilling` is respected. The infilling format is the
+        following: if suffix_first
+            " <PRE> <SUF>{suf} <MID> {pre}"
+        else:
+            " <PRE> {pre} <SUF>{suf} <MID>"
+
+        If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
+        is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
+        """
+        if reset:
+            self._tokenizer.normalizer = normalizers.Sequence(
+                [
+                    normalizers.Prepend(prepend="▁"),
+                    normalizers.Replace(pattern=" ", content="▁"),
+                ]
+            )
+            self.update_post_processor()
+            return
+
+        self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
+        pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
+        special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
+        if suffix_first:
+            # format as " <PRE> <SUF>{suf} <MID> {pre}"
+            pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+        else:
+            # format as " <PRE> {pre} <SUF>{suf} <MID>"
+            pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
+            special_tokens += [
+                (self.prefix_token, self.prefix_id),
+                (self.suffix_token, self.suffix_id),
+                (self.middle_token, self.middle_id),
+            ]
+
+        if self.add_eos_token and add_special_tokens:
+            pair += [self.eos_token]
+            special_tokens += [(self.eos_token, self.eos_token_id)]
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single="$A", pair=pair, special_tokens=special_tokens
+        )
+
+    def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
+        # hack to make sure the input is pre-process but outside rust
+        text_pair = kwargs.pop("suffix", text_pair)
+        if self.fill_token is not None and self.fill_token in text and text_pair is None:
+            text, text_pair = text.split(self.fill_token)
+
+        if text_pair is None or len(text_pair) < 1:
+            return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
+
+        if None in (self.prefix_id, self.middle_id, self.suffix_id):
+            raise ValueError(
+                "Then input includes a `prefix` and a `suffix` used for the infilling task,"
+                " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
+                f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
+            )
+
+        self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
+        tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
+        self.set_infilling_processor(True)
+        return tokens
+
+    # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. The special tokens depend on calling set_lang.
+
+        An NLLB sequence has the following format, where `X` represents the sequence:
+
+        - `input_ids` (for encoder) `X [eos, src_lang_code]`
+        - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return self.bos_token_id + token_ids_0 + self.eos_token_id
+        return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
+
+
+__all__ = ["CodeLlamaTokenizerFast"]
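
The infilling plumbing above (the `suffix` keyword popped in `encode_plus` and the ` <PRE> {pre} <SUF>{suf} <MID>` template installed by `set_infilling_processor`) is easiest to see with a short usage sketch. This is not part of the diff; the checkpoint name and the expected leading tokens are illustrative assumptions.

```python
# Hypothetical sketch (not part of the diff): drive the infilling template by passing a
# prefix plus the `suffix=` keyword that encode_plus pops from kwargs before delegating.
from transformers import CodeLlamaTokenizerFast

tokenizer = CodeLlamaTokenizerFast.from_pretrained("codellama/CodeLlama-7b-hf")  # assumed checkpoint

prefix = "def add(a, b):\n    "
suffix = "\n    return result\n"

encoding = tokenizer.encode_plus(prefix, suffix=suffix, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
print(tokens[:4])  # expected to start with the control tokens, e.g. ['<s>', '▁<PRE>', ...]
```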

.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py ADDED
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_convnext import *
+    from .feature_extraction_convnext import *
+    from .image_processing_convnext import *
+    from .modeling_convnext import *
+    from .modeling_tf_convnext import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
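
A note on the `_LazyModule` registration above: the submodules listed in the `TYPE_CHECKING` branch are only imported when one of their attributes is first accessed, not when the package itself is imported. A minimal sketch, assuming a working `transformers` install (not part of the diff):

```python
# Minimal sketch (not part of the diff): attribute access on the lazily registered
# package is what triggers the import of the underlying submodule.
import transformers.models.convnext as convnext_pkg

config_cls = convnext_pkg.ConvNextConfig  # lazily imports configuration_convnext here
print(config_cls().model_type)  # "convnext"
```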

.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc ADDED
Binary file (7.08 kB)

.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc ADDED
Binary file (16.5 kB)

.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py ADDED
@@ -0,0 +1,142 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ConvNeXT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
+    ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the ConvNeXT
+    [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        patch_size (`int`, *optional*, defaults to 4):
+            Patch size to use in the patch embedding layer.
+        num_stages (`int`, *optional*, defaults to 4):
+            The number of stages in the model.
+        hidden_sizes (`List[int]`, *optional*, defaults to [96, 192, 384, 768]):
+            Dimensionality (hidden size) at each stage.
+        depths (`List[int]`, *optional*, defaults to [3, 3, 9, 3]):
+            Depth (number of blocks) for each stage.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
+            The initial value for the layer scale.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+            The drop rate for stochastic depth.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+    ```python
+    >>> from transformers import ConvNextConfig, ConvNextModel
+
+    >>> # Initializing a ConvNext convnext-tiny-224 style configuration
+    >>> configuration = ConvNextConfig()
+
+    >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration
+    >>> model = ConvNextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "convnext"
+
+    def __init__(
+        self,
+        num_channels=3,
+        patch_size=4,
+        num_stages=4,
+        hidden_sizes=None,
+        depths=None,
+        hidden_act="gelu",
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        layer_scale_init_value=1e-6,
+        drop_path_rate=0.0,
+        image_size=224,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.num_stages = num_stages
+        self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
+        self.depths = [3, 3, 9, 3] if depths is None else depths
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.layer_scale_init_value = layer_scale_init_value
+        self.drop_path_rate = drop_path_rate
+        self.image_size = image_size
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+class ConvNextOnnxConfig(OnnxConfig):
+    torch_onnx_minimum_version = version.parse("1.11")
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-5
+
+
+__all__ = ["ConvNextConfig", "ConvNextOnnxConfig"]
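
The `out_features` / `out_indices` arguments above are aligned against the generated `stage_names` by `get_aligned_output_features_output_indices` at init time. A minimal sketch of that alignment (illustrative, not part of the diff):

```python
# Minimal sketch (not part of the diff): out_features and out_indices are kept in sync
# with the stage_names list built in __init__.
from transformers import ConvNextConfig

config = ConvNextConfig(out_features=["stage2", "stage4"])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage4']
print(config.out_indices)   # indices of those stages in stage_names, e.g. [2, 4]
```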

.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py ADDED
@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ConvNeXT."""
+
+import warnings
+
+from ...utils import logging
+from .image_processing_convnext import ConvNextImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextFeatureExtractor(ConvNextImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
+            " Please use ConvNextImageProcessor instead.",
+            FutureWarning,
+        )
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["ConvNextFeatureExtractor"]
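
Since `ConvNextFeatureExtractor` is only a deprecation alias for `ConvNextImageProcessor`, the one observable difference is the warning raised at construction. A small sketch (not part of the diff):

```python
# Small sketch (not part of the diff): the alias warns at construction and then behaves
# exactly like ConvNextImageProcessor.
import warnings
from transformers import ConvNextFeatureExtractor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    extractor = ConvNextFeatureExtractor()

print(any(issubclass(w.category, FutureWarning) for w in caught))  # True
```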

.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py ADDED
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ConvNeXT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+    center_crop,
+    get_resize_output_image_size,
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class ConvNextImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a ConvNeXT image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden
+            by `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 384}`):
+            Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
+            resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
+            be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
+            `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
+            be overriden by `size` in the `preprocess` method.
+        crop_pct (`float` *optional*, defaults to 224 / 256):
+            Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
+            overriden by `crop_pct` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overriden by `resample` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        crop_pct: float = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 384}
+        size = get_size_dict(size, default_to_square=False)
+
+        self.do_resize = do_resize
+        self.size = size
+        # Default value set here for backwards compatibility where the value in config is None
+        self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        crop_pct: float,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
+                `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
+                Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
+                after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
+            crop_pct (`float`):
+                Percentage of the image to crop. Only has an effect if size < 384.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred from the input
+                image.
+        """
+        size = get_size_dict(size, default_to_square=False)
+        if "shortest_edge" not in size:
+            raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
+        shortest_edge = size["shortest_edge"]
+
+        if shortest_edge < 384:
+            # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+            resize_shortest_edge = int(shortest_edge / crop_pct)
+            resize_size = get_resize_output_image_size(
+                image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
+            )
+            image = resize(
+                image=image,
+                size=resize_size,
+                resample=resample,
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+            # then crop to (shortest_edge, shortest_edge)
+            return center_crop(
+                image=image,
+                size=(shortest_edge, shortest_edge),
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+        else:
+            # warping (no cropping) when evaluated at 384 or larger
+            return resize(
+                image,
+                size=(shortest_edge, shortest_edge),
+                resample=resample,
+                data_format=data_format,
+                input_data_format=input_data_format,
+                **kwargs,
+            )
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        crop_pct: float = None,
+        resample: PILImageResampling = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
+                is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
+                image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
+                `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
+            crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
+                Percentage of the image to crop if size < 384.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                - Unset: Return a list of `np.ndarray`.
+                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        crop_pct = crop_pct if crop_pct is not None else self.crop_pct
+        resample = resample if resample is not None else self.resample
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        size = size if size is not None else self.size
+        size = get_size_dict(size, default_to_square=False)
+
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        if do_resize:
+            images = [
+                self.resize(
+                    image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
+                )
+                for image in images
+            ]
+
+        if do_rescale:
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        if do_normalize:
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]
+
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["ConvNextImageProcessor"]
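
The `resize` logic above has two branches: below 384 the shortest edge is first scaled up by `1 / crop_pct` and the result is center-cropped to a square; at 384 and above the image is warped directly to a square. A sketch of both paths on random data (illustrative, not part of the diff):

```python
# Illustrative sketch (not part of the diff): exercise both resize branches of
# ConvNextImageProcessor on a random HWC uint8 image.
import numpy as np
from transformers import ConvNextImageProcessor

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

small = ConvNextImageProcessor(size={"shortest_edge": 224})  # resize-then-crop path (uses crop_pct)
large = ConvNextImageProcessor(size={"shortest_edge": 384})  # direct warp path

print(small.preprocess(image, return_tensors="np")["pixel_values"].shape)  # (1, 3, 224, 224)
print(large.preprocess(image, return_tensors="np")["pixel_values"].shape)  # (1, 3, 384, 384)
```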
.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ConvNext model."""
|
| 16 |
+
|
| 17 |
+
from typing import Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
from torch import nn
|
| 22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...modeling_outputs import (
|
| 26 |
+
BackboneOutput,
|
| 27 |
+
BaseModelOutputWithNoAttention,
|
| 28 |
+
BaseModelOutputWithPoolingAndNoAttention,
|
| 29 |
+
ImageClassifierOutputWithNoAttention,
|
| 30 |
+
)
|
| 31 |
+
from ...modeling_utils import PreTrainedModel
|
| 32 |
+
from ...utils import (
|
| 33 |
+
add_code_sample_docstrings,
|
| 34 |
+
add_start_docstrings,
|
| 35 |
+
add_start_docstrings_to_model_forward,
|
| 36 |
+
logging,
|
| 37 |
+
replace_return_docstrings,
|
| 38 |
+
)
|
| 39 |
+
from ...utils.backbone_utils import BackboneMixin
|
| 40 |
+
from .configuration_convnext import ConvNextConfig
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
# General docstring
|
| 46 |
+
_CONFIG_FOR_DOC = "ConvNextConfig"
|
| 47 |
+
|
| 48 |
+
# Base docstring
|
| 49 |
+
_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
|
| 50 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]
|
| 51 |
+
|
| 52 |
+
# Image classification docstring
|
| 53 |
+
_IMAGE_CLASS_CHECKPOINT = "facebook/convnext-tiny-224"
|
| 54 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Copied from transformers.models.beit.modeling_beit.drop_path
|
| 58 |
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 61 |
+
|
| 62 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
| 63 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 64 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
| 65 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
| 66 |
+
argument.
|
| 67 |
+
"""
|
| 68 |
+
if drop_prob == 0.0 or not training:
|
| 69 |
+
return input
|
| 70 |
+
keep_prob = 1 - drop_prob
|
| 71 |
+
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 72 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
|
| 73 |
+
random_tensor.floor_() # binarize
|
| 74 |
+
output = input.div(keep_prob) * random_tensor
|
| 75 |
+
return output
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
|
| 79 |
+
class ConvNextDropPath(nn.Module):
|
| 80 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 81 |
+
|
| 82 |
+
def __init__(self, drop_prob: Optional[float] = None) -> None:
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.drop_prob = drop_prob
|
| 85 |
+
|
| 86 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
return drop_path(hidden_states, self.drop_prob, self.training)
|
| 88 |
+
|
| 89 |
+
def extra_repr(self) -> str:
|
| 90 |
+
return "p={}".format(self.drop_prob)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ConvNextLayerNorm(nn.Module):
|
| 94 |
+
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 95 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
|
| 96 |
+
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 100 |
+
super().__init__()
|
| 101 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 102 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 103 |
+
self.eps = eps
|
| 104 |
+
self.data_format = data_format
|
| 105 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 106 |
+
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
|
| 107 |
+
self.normalized_shape = (normalized_shape,)
|
| 108 |
+
|
| 109 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
if self.data_format == "channels_last":
|
| 111 |
+
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
| 112 |
+
elif self.data_format == "channels_first":
|
| 113 |
+
input_dtype = x.dtype
|
| 114 |
+
x = x.float()
|
| 115 |
+
u = x.mean(1, keepdim=True)
|
| 116 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 117 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 118 |
+
x = x.to(dtype=input_dtype)
|
| 119 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class ConvNextEmbeddings(nn.Module):
|
| 124 |
+
"""This class is comparable to (and inspired by) the SwinEmbeddings class
|
| 125 |
+
found in src/transformers/models/swin/modeling_swin.py.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def __init__(self, config):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.patch_embeddings = nn.Conv2d(
|
| 131 |
+
config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
|
| 132 |
+
)
|
| 133 |
+
self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
|
| 134 |
+
self.num_channels = config.num_channels
|
| 135 |
+
|
| 136 |
+
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
| 137 |
+
num_channels = pixel_values.shape[1]
|
| 138 |
+
if num_channels != self.num_channels:
|
| 139 |
+
raise ValueError(
|
| 140 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 141 |
+
)
|
| 142 |
+
embeddings = self.patch_embeddings(pixel_values)
|
| 143 |
+
embeddings = self.layernorm(embeddings)
|
| 144 |
+
return embeddings
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ConvNextLayer(nn.Module):
|
| 148 |
+
"""This corresponds to the `Block` class in the original implementation.
|
| 149 |
+
|
| 150 |
+
There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
|
| 151 |
+
H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
|
| 152 |
+
|
| 153 |
+
The authors used (2) as they find it slightly faster in PyTorch.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
config ([`ConvNextConfig`]): Model configuration class.
|
| 157 |
+
dim (`int`): Number of input channels.
|
| 158 |
+
drop_path (`float`): Stochastic depth rate. Default: 0.0.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(self, config, dim, drop_path=0):
|
| 162 |
+
super().__init__()
|
| 163 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
| 164 |
+
self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
|
| 165 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
| 166 |
+
self.act = ACT2FN[config.hidden_act]
|
| 167 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 168 |
+
self.layer_scale_parameter = (
|
| 169 |
+
nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
| 170 |
+
if config.layer_scale_init_value > 0
|
| 171 |
+
else None
|
| 172 |
+
)
|
| 173 |
+
self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 174 |
+
|
| 175 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
|
| 176 |
+
input = hidden_states
|
| 177 |
+
x = self.dwconv(hidden_states)
|
| 178 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 179 |
+
x = self.layernorm(x)
|
| 180 |
+
x = self.pwconv1(x)
|
| 181 |
+
x = self.act(x)
|
| 182 |
+
x = self.pwconv2(x)
|
| 183 |
+
if self.layer_scale_parameter is not None:
|
| 184 |
+
x = self.layer_scale_parameter * x
|
| 185 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 186 |
+
|
| 187 |
+
x = input + self.drop_path(x)
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class ConvNextStage(nn.Module):
|
| 192 |
+
"""ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
config ([`ConvNextConfig`]): Model configuration class.
|
| 196 |
+
in_channels (`int`): Number of input channels.
|
| 197 |
+
out_channels (`int`): Number of output channels.
|
| 198 |
+
depth (`int`): Number of residual blocks.
|
| 199 |
+
drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
|
| 203 |
+
super().__init__()
|
| 204 |
+
|
| 205 |
+
if in_channels != out_channels or stride > 1:
|
| 206 |
+
self.downsampling_layer = nn.Sequential(
|
| 207 |
+
ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
|
| 208 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
|
| 209 |
+
)
|
| 210 |
+
else:
|
| 211 |
+
self.downsampling_layer = nn.Identity()
|
| 212 |
+
drop_path_rates = drop_path_rates or [0.0] * depth
|
| 213 |
+
self.layers = nn.Sequential(
|
| 214 |
+
*[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
|
| 218 |
+
hidden_states = self.downsampling_layer(hidden_states)
|
| 219 |
+
hidden_states = self.layers(hidden_states)
|
| 220 |
+
return hidden_states
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class ConvNextEncoder(nn.Module):
|
| 224 |
+
def __init__(self, config):
|
| 225 |
+
super().__init__()
|
| 226 |
+
self.stages = nn.ModuleList()
|
| 227 |
+
drop_path_rates = [
|
| 228 |
+
x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
|
| 229 |
+
]
|
| 230 |
+
prev_chs = config.hidden_sizes[0]
|
| 231 |
+
for i in range(config.num_stages):
|
| 232 |
+
out_chs = config.hidden_sizes[i]
|
| 233 |
+
stage = ConvNextStage(
|
| 234 |
+
config,
|
| 235 |
+
in_channels=prev_chs,
|
| 236 |
+
out_channels=out_chs,
|
| 237 |
+
stride=2 if i > 0 else 1,
|
| 238 |
+
depth=config.depths[i],
|
| 239 |
+
drop_path_rates=drop_path_rates[i],
|
| 240 |
+
)
|
| 241 |
+
self.stages.append(stage)
|
| 242 |
+
prev_chs = out_chs
|
| 243 |
+
|
| 244 |
+
def forward(
|
| 245 |
+
self,
|
| 246 |
+
hidden_states: torch.FloatTensor,
|
| 247 |
+
output_hidden_states: Optional[bool] = False,
|
| 248 |
+
return_dict: Optional[bool] = True,
|
| 249 |
+
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
|
| 250 |
+
all_hidden_states = () if output_hidden_states else None
|
| 251 |
+
|
| 252 |
+
for i, layer_module in enumerate(self.stages):
|
| 253 |
+
if output_hidden_states:
|
| 254 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 255 |
+
|
| 256 |
+
hidden_states = layer_module(hidden_states)
|
| 257 |
+
|
| 258 |
+
if output_hidden_states:
|
| 259 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 260 |
+
|
| 261 |
+
if not return_dict:
|
| 262 |
+
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
|
| 263 |
+
|
| 264 |
+
return BaseModelOutputWithNoAttention(
|
| 265 |
+
last_hidden_state=hidden_states,
|
| 266 |
+
hidden_states=all_hidden_states,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class ConvNextPreTrainedModel(PreTrainedModel):
|
| 271 |
+
"""
|
| 272 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 273 |
+
models.
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
config_class = ConvNextConfig
|
| 277 |
+
base_model_prefix = "convnext"
|
| 278 |
+
main_input_name = "pixel_values"
|
| 279 |
+
_no_split_modules = ["ConvNextLayer"]
|
| 280 |
+
|
| 281 |
+
def _init_weights(self, module):
|
| 282 |
+
"""Initialize the weights"""
|
| 283 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 284 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 285 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 286 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 287 |
+
if module.bias is not None:
|
| 288 |
+
module.bias.data.zero_()
|
| 289 |
+
elif isinstance(module, nn.LayerNorm):
|
| 290 |
+
module.bias.data.zero_()
|
| 291 |
+
module.weight.data.fill_(1.0)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
CONVNEXT_START_DOCSTRING = r"""
|
| 295 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 296 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 297 |
+
behavior.
|
| 298 |
+
|
| 299 |
+
Parameters:
|
| 300 |
+
config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
|
| 301 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 302 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 303 |
+
"""
|
| 304 |
+
|
| 305 |
+
CONVNEXT_INPUTS_DOCSTRING = r"""
|
| 306 |
+
Args:
|
| 307 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 308 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
| 309 |
+
[`ConvNextImageProcessor.__call__`] for details.
|
| 310 |
+
|
| 311 |
+
output_hidden_states (`bool`, *optional*):
|
| 312 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 313 |
+
more detail.
|
| 314 |
+
return_dict (`bool`, *optional*):
|
| 315 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
@add_start_docstrings(
|
| 320 |
+
"The bare ConvNext model outputting raw features without any specific head on top.",
|
| 321 |
+
CONVNEXT_START_DOCSTRING,
|
| 322 |
+
)
|
| 323 |
+
class ConvNextModel(ConvNextPreTrainedModel):
|
| 324 |
+
def __init__(self, config):
|
| 325 |
+
super().__init__(config)
|
| 326 |
+
self.config = config
|
| 327 |
+
|
| 328 |
+
self.embeddings = ConvNextEmbeddings(config)
|
| 329 |
+
self.encoder = ConvNextEncoder(config)
|
| 330 |
+
|
| 331 |
+
# final layernorm layer
|
| 332 |
+
self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
|
| 333 |
+
|
| 334 |
+
# Initialize weights and apply final processing
|
| 335 |
+
self.post_init()
|
| 336 |
+
|
| 337 |
+
@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
|
| 338 |
+
@add_code_sample_docstrings(
|
| 339 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 340 |
+
output_type=BaseModelOutputWithPoolingAndNoAttention,
|
| 341 |
+
config_class=_CONFIG_FOR_DOC,
|
| 342 |
+
modality="vision",
|
| 343 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 344 |
+
)
|
| 345 |
+
def forward(
|
| 346 |
+
self,
|
| 347 |
+
pixel_values: torch.FloatTensor = None,
|
| 348 |
+
output_hidden_states: Optional[bool] = None,
|
| 349 |
+
return_dict: Optional[bool] = None,
|
| 350 |
+
) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
|
| 351 |
+
output_hidden_states = (
|
| 352 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 353 |
+
)
|
| 354 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 355 |
+
|
| 356 |
+
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]

        # global average pooling, (N, C, H, W) -> (N, C)
        pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@add_start_docstrings(
    """
    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    CONVNEXT_START_DOCSTRING,
)
class ConvNextForImageClassification(ConvNextPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.convnext = ConvNextModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.convnext(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
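
The loss selection above depends only on `config.problem_type` and the label dtype. A minimal standalone sketch of the same decision rule (plain PyTorch, hypothetical tensors, not part of the vendored file):

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


def pick_loss(num_labels: int, labels: torch.Tensor) -> torch.nn.Module:
    # Mirrors the branch above: one label -> regression, integer labels ->
    # single-label classification, anything else -> multi-label BCE-with-logits.
    if num_labels == 1:
        return MSELoss()
    if labels.dtype in (torch.long, torch.int):
        return CrossEntropyLoss()
    return BCEWithLogitsLoss()


logits = torch.randn(4, 3)           # hypothetical batch of 4 examples, 3 classes
labels = torch.tensor([0, 2, 1, 2])  # integer class ids -> CrossEntropyLoss
loss = pick_loss(3, labels)(logits, labels)
```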


@add_start_docstrings(
    """
    ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    CONVNEXT_START_DOCSTRING,
)
class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.embeddings = ConvNextEmbeddings(config)
        self.encoder = ConvNextEncoder(config)
        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self._out_features, self.channels):
            hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
        >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=None,
        )


__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]
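
For reference, a short usage sketch of the classification model defined above. The `facebook/convnext-tiny-224` checkpoint and the COCO image URL are the ones this file's docstrings already reference; treat this as an illustrative sketch, not part of the vendored file (network access assumed):

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ConvNextForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")

with torch.no_grad():
    logits = model(**processor(image, return_tensors="pt")).logits

# The pooled feature is the layer-normed global average over H and W (see forward above).
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
```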
.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py
ADDED
|
@@ -0,0 +1,669 @@
# coding=utf-8
# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 ConvNext model."""

from __future__ import annotations

from typing import List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_convnext import ConvNextConfig


logger = logging.get_logger(__name__)


_CONFIG_FOR_DOC = "ConvNextConfig"
_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"


class TFConvNextDropPath(keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x: tf.Tensor, training=None):
        if training:
            keep_prob = 1 - self.drop_path
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x


class TFConvNextEmbeddings(keras.layers.Layer):
    """This class is comparable to (and inspired by) the SwinEmbeddings class
    found in src/transformers/models/swin/modeling_swin.py.
    """

    def __init__(self, config: ConvNextConfig, **kwargs):
        super().__init__(**kwargs)
        self.patch_embeddings = keras.layers.Conv2D(
            filters=config.hidden_sizes[0],
            kernel_size=config.patch_size,
            strides=config.patch_size,
            name="patch_embeddings",
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer=keras.initializers.Zeros(),
        )
        self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
        self.num_channels = config.num_channels
        self.config = config

    def call(self, pixel_values):
        if isinstance(pixel_values, dict):
            pixel_values = pixel_values["pixel_values"]

        tf.debugging.assert_equal(
            shape_list(pixel_values)[1],
            self.num_channels,
            message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
        )

        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        embeddings = self.patch_embeddings(pixel_values)
        embeddings = self.layernorm(embeddings)
        return embeddings

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build([None, None, None, self.config.num_channels])
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])


class TFConvNextLayer(keras.layers.Layer):
    """This corresponds to the `Block` class in the original implementation.

    There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
    H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back

    The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
    NHWC ordering, we can just apply the operations straight-away without the permutation.

    Args:
        config ([`ConvNextConfig`]): Model configuration class.
        dim (`int`): Number of input channels.
        drop_path (`float`): Stochastic depth rate. Default: 0.0.
    """

    def __init__(self, config, dim, drop_path=0.0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.config = config
        self.dwconv = keras.layers.Conv2D(
            filters=dim,
            kernel_size=7,
            padding="same",
            groups=dim,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="dwconv",
        )  # depthwise conv
        self.layernorm = keras.layers.LayerNormalization(
            epsilon=1e-6,
            name="layernorm",
        )
        self.pwconv1 = keras.layers.Dense(
            units=4 * dim,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="pwconv1",
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = get_tf_activation(config.hidden_act)
        self.pwconv2 = keras.layers.Dense(
            units=dim,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="pwconv2",
        )
        # Using `layers.Activation` instead of `tf.identity` to better control `training`
        # behaviour.
        self.drop_path = (
            TFConvNextDropPath(drop_path, name="drop_path")
            if drop_path > 0.0
            else keras.layers.Activation("linear", name="drop_path")
        )

    def build(self, input_shape: tf.TensorShape = None):
        # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
        self.layer_scale_parameter = (
            self.add_weight(
                shape=(self.dim,),
                initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
                trainable=True,
                name="layer_scale_parameter",
            )
            if self.config.layer_scale_init_value > 0
            else None
        )

        if self.built:
            return
        self.built = True
        if getattr(self, "dwconv", None) is not None:
            with tf.name_scope(self.dwconv.name):
                self.dwconv.build([None, None, None, self.dim])
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, None, self.dim])
        if getattr(self, "pwconv1", None) is not None:
            with tf.name_scope(self.pwconv1.name):
                self.pwconv1.build([None, None, self.dim])
        if getattr(self, "pwconv2", None) is not None:
            with tf.name_scope(self.pwconv2.name):
                self.pwconv2.build([None, None, 4 * self.dim])
        if getattr(self, "drop_path", None) is not None:
            with tf.name_scope(self.drop_path.name):
                self.drop_path.build(None)

    def call(self, hidden_states, training=False):
        input = hidden_states
        x = self.dwconv(hidden_states)
        x = self.layernorm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.layer_scale_parameter is not None:
            x = self.layer_scale_parameter * x

        x = input + self.drop_path(x, training=training)
        return x
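
The layer above is the standard ConvNeXt residual block. A stripped-down sketch of the same dataflow with plain `tf.keras` layers and a hypothetical channel count (no layer scale or drop path), only to make the order of operations explicit:

```python
import tensorflow as tf

dim = 96  # hypothetical number of channels
dwconv = tf.keras.layers.Conv2D(dim, kernel_size=7, padding="same", groups=dim)  # depthwise 7x7
norm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
pw1 = tf.keras.layers.Dense(4 * dim)  # pointwise expansion
pw2 = tf.keras.layers.Dense(dim)      # pointwise projection back to dim

x = tf.random.normal((1, 56, 56, dim))                   # NHWC, as in this file
out = x + pw2(tf.nn.gelu(pw1(norm(dwconv(x)))))          # residual around the block
```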


class TFConvNextStage(keras.layers.Layer):
    """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.

    Args:
        config (`ConvNextV2Config`):
            Model configuration class.
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`):
            Number of output channels.
        depth (`int`):
            Number of residual blocks.
        drop_path_rates(`List[float]`):
            Stochastic depth rates for each layer.
    """

    def __init__(
        self,
        config: ConvNextConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 2,
        stride: int = 2,
        depth: int = 2,
        drop_path_rates: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if in_channels != out_channels or stride > 1:
            self.downsampling_layer = [
                keras.layers.LayerNormalization(
                    epsilon=1e-6,
                    name="downsampling_layer.0",
                ),
                # Inputs to this layer will follow NHWC format since we
                # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
                # layer. All the outputs throughout the model will be in NHWC
                # from this point on until the output where we again change to
                # NCHW.
                keras.layers.Conv2D(
                    filters=out_channels,
                    kernel_size=kernel_size,
                    strides=stride,
                    kernel_initializer=get_initializer(config.initializer_range),
                    bias_initializer=keras.initializers.Zeros(),
                    name="downsampling_layer.1",
                ),
            ]
        else:
            self.downsampling_layer = [tf.identity]

        drop_path_rates = drop_path_rates or [0.0] * depth
        self.layers = [
            TFConvNextLayer(
                config,
                dim=out_channels,
                drop_path=drop_path_rates[j],
                name=f"layers.{j}",
            )
            for j in range(depth)
        ]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def call(self, hidden_states):
        for layer in self.downsampling_layer:
            hidden_states = layer(hidden_states)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
        if self.in_channels != self.out_channels or self.stride > 1:
            with tf.name_scope(self.downsampling_layer[0].name):
                self.downsampling_layer[0].build([None, None, None, self.in_channels])
            with tf.name_scope(self.downsampling_layer[1].name):
                self.downsampling_layer[1].build([None, None, None, self.in_channels])


class TFConvNextEncoder(keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.stages = []
        drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
        drop_path_rates = tf.split(drop_path_rates, config.depths)
        drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
        prev_chs = config.hidden_sizes[0]
        for i in range(config.num_stages):
            out_chs = config.hidden_sizes[i]
            stage = TFConvNextStage(
                config,
                in_channels=prev_chs,
                out_channels=out_chs,
                stride=2 if i > 0 else 1,
                depth=config.depths[i],
                drop_path_rates=drop_path_rates[i],
                name=f"stages.{i}",
            )
            self.stages.append(stage)
            prev_chs = out_chs

    def call(self, hidden_states, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.stages):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            hidden_states = layer_module(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

    def build(self, input_shape=None):
        for stage in self.stages:
            with tf.name_scope(stage.name):
                stage.build(None)


@keras_serializable
class TFConvNextMainLayer(keras.layers.Layer):
    config_class = ConvNextConfig

    def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
        self.encoder = TFConvNextEncoder(config, name="encoder")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        # We are setting the `data_format` like so because from here on we will revert to the
        # NCHW output format
        self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values, training=training)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        last_hidden_state = encoder_outputs[0]
        # Change to NCHW output format have uniformity in the modules
        last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
        pooled_output = self.layernorm(self.pooler(last_hidden_state))

        # Change the other hidden state outputs to NCHW as well
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])

        if not return_dict:
            hidden_states = hidden_states if output_hidden_states else ()
            return (last_hidden_state, pooled_output) + hidden_states

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, self.config.hidden_sizes[-1]])


class TFConvNextPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ConvNextConfig
    base_model_prefix = "convnext"
    main_input_name = "pixel_values"


CONVNEXT_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

CONVNEXT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConvNextImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
"""


@add_start_docstrings(
    "The bare ConvNext model outputting raw features without any specific head on top.",
    CONVNEXT_START_DOCSTRING,
)
class TFConvNextModel(TFConvNextPreTrainedModel):
    def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, TFConvNextModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
        >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")

        >>> inputs = image_processor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        outputs = self.convnext(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            return (outputs[0],) + outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "convnext", None) is not None:
            with tf.name_scope(self.convnext.name):
                self.convnext.build(None)


@add_start_docstrings(
    """
    ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    CONVNEXT_START_DOCSTRING,
)
class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.convnext = TFConvNextMainLayer(config, name="convnext")

        # Classifier head
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            bias_initializer="zeros",
            name="classifier",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification
        >>> import tensorflow as tf
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
        >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")

        >>> inputs = image_processor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        outputs = self.convnext(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "convnext", None) is not None:
            with tf.name_scope(self.convnext.name):
                self.convnext.build(None)
        if getattr(self, "classifier", None) is not None:
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    self.classifier.build([None, None, self.config.hidden_sizes[-1]])


__all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"]
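
As an aside, `TFConvNextDropPath` near the top of this file implements per-sample stochastic depth. A small self-contained sketch of the same computation (hypothetical shapes, training mode assumed), mirroring the layer's `call` logic:

```python
import tensorflow as tf


def drop_path(x: tf.Tensor, drop_rate: float) -> tf.Tensor:
    # Each example in the batch is either zeroed out (with probability drop_rate)
    # or rescaled by 1 / keep_prob so the expected value stays unchanged.
    keep_prob = 1 - drop_rate
    shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)  # broadcast over all but batch
    random_tensor = tf.floor(keep_prob + tf.random.uniform(shape, 0, 1))
    return (x / keep_prob) * random_tensor


x = tf.ones((4, 8))                      # hypothetical batch of 4 feature vectors
print(drop_path(x, drop_rate=0.5))       # roughly half the rows zeroed, the rest scaled by 2
```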
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_decision_transformer import *
    from .modeling_decision_transformer import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
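
The `_LazyModule` indirection above means the heavy submodules are only imported on first attribute access. A rough illustration of the general pattern (standard library only, not the actual `_LazyModule` implementation; the attribute/submodule names are hypothetical):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Toy stand-in: resolve attributes from submodules on first access, then cache them."""

    def __init__(self, name: str, attr_to_submodule: dict):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule

    def __getattr__(self, attr: str):
        submodule = self._attr_to_submodule[attr]  # e.g. "modeling_decision_transformer"
        module = importlib.import_module(f"{self.__name__}.{submodule}")
        value = getattr(module, attr)
        setattr(self, attr, value)                 # cache so __getattr__ is not hit again
        return value
```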
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (808 Bytes).

.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc
ADDED
Binary file (6.8 kB).

.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc
ADDED
Binary file (47.1 kB).
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py
ADDED
|
@@ -0,0 +1,963 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DecisionTransformer model."""

import math
import os
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_decision_transformer import DecisionTransformerConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
_CONFIG_FOR_DOC = "DecisionTransformerConfig"


# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Load tf checkpoints in a pytorch model"""
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model


# Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward
def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    if module.scale_attn_weights:
        attn_weights = attn_weights / torch.full(
            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
        )

    # Layer-wise attention scaling
    if module.scale_attn_by_inverse_layer_idx:
        attn_weights = attn_weights / float(module.layer_idx + 1)

    if not module.is_cross_attention:
        # if only "normal" attention layer implements causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = module.attn_dropout(attn_weights)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2)

    return attn_output, attn_weights
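
A small numeric sketch of the causal-masking step above, showing how the precomputed lower-triangular `bias` buffer is sliced to `[key_length - query_length : key_length, :key_length]`. The sizes are hypothetical and batching is omitted; this is an illustration, not part of the vendored file:

```python
import torch

max_positions = 8
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))

# Decoding one new token against a cache of 5 previous keys:
query_length, key_length = 1, 6
causal_mask = bias[key_length - query_length : key_length, :key_length]
print(causal_mask)  # one row of six True values: the new token may attend to every key

scores = torch.randn(query_length, key_length)
neg_inf = torch.full_like(scores, torch.finfo(scores.dtype).min)
weights = torch.where(causal_mask, scores, neg_inf).softmax(dim=-1)
```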
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
|
| 147 |
+
class DecisionTransformerGPT2Attention(nn.Module):
|
| 148 |
+
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
| 149 |
+
super().__init__()
|
| 150 |
+
self.config = config
|
| 151 |
+
max_positions = config.max_position_embeddings
|
| 152 |
+
self.register_buffer(
|
| 153 |
+
"bias",
|
| 154 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
| 155 |
+
1, 1, max_positions, max_positions
|
| 156 |
+
),
|
| 157 |
+
persistent=False,
|
| 158 |
+
)
|
| 159 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
|
| 160 |
+
|
| 161 |
+
self.embed_dim = config.hidden_size
|
| 162 |
+
self.num_heads = config.num_attention_heads
|
| 163 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 164 |
+
self.split_size = self.embed_dim
|
| 165 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 166 |
+
raise ValueError(
|
| 167 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
| 168 |
+
f" {self.num_heads})."
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
self.scale_attn_weights = config.scale_attn_weights
|
| 172 |
+
self.is_cross_attention = is_cross_attention
|
| 173 |
+
|
| 174 |
+
# Layer-wise attention scaling, reordering, and upcasting
|
| 175 |
+
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
| 176 |
+
self.layer_idx = layer_idx
|
| 177 |
+
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
| 178 |
+
|
| 179 |
+
if self.is_cross_attention:
|
| 180 |
+
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
| 181 |
+
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
| 182 |
+
else:
|
| 183 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
| 184 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
| 185 |
+
|
| 186 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
| 187 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 188 |
+
self.is_causal = True
|
| 189 |
+
|
| 190 |
+
self.pruned_heads = set()
|
| 191 |
+
|
| 192 |
+
def prune_heads(self, heads):
|
| 193 |
+
if len(heads) == 0:
|
| 194 |
+
return
|
| 195 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
| 196 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
| 197 |
+
|
| 198 |
+
# Prune conv1d layers
|
| 199 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
| 200 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
| 201 |
+
|
| 202 |
+
# Update hyper params
|
| 203 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
| 204 |
+
self.num_heads = self.num_heads - len(heads)
|
| 205 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 206 |
+
|
| 207 |
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
| 208 |
+
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
|
| 209 |
+
bsz, num_heads, q_seq_len, dk = query.size()
|
| 210 |
+
_, _, k_seq_len, _ = key.size()
|
| 211 |
+
|
| 212 |
+
# Preallocate attn_weights for `baddbmm`
|
| 213 |
+
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
|
| 214 |
+
|
| 215 |
+
# Compute Scale Factor
|
| 216 |
+
scale_factor = 1.0
|
| 217 |
+
if self.scale_attn_weights:
|
| 218 |
+
scale_factor /= float(value.size(-1)) ** 0.5
|
| 219 |
+
|
| 220 |
+
if self.scale_attn_by_inverse_layer_idx:
|
| 221 |
+
scale_factor /= float(self.layer_idx + 1)
|
| 222 |
+
|
| 223 |
+
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
| 224 |
+
with torch.amp.autocast(query.device.type, enabled=False):
|
| 225 |
+
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
| 226 |
+
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
| 227 |
+
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
| 228 |
+
|
| 229 |
+
if not self.is_cross_attention:
|
| 230 |
+
# Only the "normal" attention layer implements a causal mask
|
| 231 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
| 232 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
| 233 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
| 234 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
| 235 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
| 236 |
+
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
| 237 |
+
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
| 238 |
+
|
| 239 |
+
if attention_mask is not None:
|
| 240 |
+
# Apply the attention mask
|
| 241 |
+
attn_weights = attn_weights + attention_mask
|
| 242 |
+
|
| 243 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 244 |
+
|
| 245 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
| 246 |
+
if attn_weights.dtype != torch.float32:
|
| 247 |
+
raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
|
| 248 |
+
attn_weights = attn_weights.type(value.dtype)
|
| 249 |
+
attn_weights = self.attn_dropout(attn_weights)
|
| 250 |
+
|
| 251 |
+
# Mask heads if we want to
|
| 252 |
+
if head_mask is not None:
|
| 253 |
+
attn_weights = attn_weights * head_mask
|
| 254 |
+
|
| 255 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 256 |
+
attn_output = attn_output.transpose(1, 2)
|
| 257 |
+
|
| 258 |
+
return attn_output, attn_weights
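The `_upcast_and_reordered_attn` path above folds every scale factor into a single float32 `baddbmm`. A minimal standalone sketch (toy shapes chosen here, not taken from the file; `layer_idx = 0` assumed) of the equivalence it relies on:

```python
import torch

# Toy shapes for illustration only
bsz, num_heads, q_len, k_len, dk = 2, 4, 5, 5, 8
query = torch.randn(bsz, num_heads, q_len, dk, dtype=torch.float16)
key = torch.randn(bsz, num_heads, k_len, dk, dtype=torch.float16)

scale_factor = 1.0 / float(dk) ** 0.5   # scale_attn_weights
scale_factor /= float(0 + 1)            # scale_attn_by_inverse_layer_idx with layer_idx = 0

# beta=0 ignores the preallocated buffer; alpha applies the scaling inside the fp32 matmul
attn = torch.empty(bsz * num_heads, q_len, k_len, dtype=torch.float32)
q = query.reshape(-1, q_len, dk)
k = key.transpose(-1, -2).reshape(-1, dk, k_len)
attn = torch.baddbmm(attn, q.float(), k.float(), beta=0, alpha=scale_factor)

# Same result as scaling a plain matmul of the upcast tensors
reference = scale_factor * torch.matmul(query.float(), key.float().transpose(-1, -2))
assert torch.allclose(attn.reshape(bsz, num_heads, q_len, k_len), reference, atol=1e-5)
```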
|
| 259 |
+
|
| 260 |
+
def forward(
|
| 261 |
+
self,
|
| 262 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
| 263 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
| 264 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 265 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 266 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 267 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 268 |
+
use_cache: Optional[bool] = False,
|
| 269 |
+
output_attentions: Optional[bool] = False,
|
| 270 |
+
**kwargs,
|
| 271 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
| 272 |
+
if encoder_hidden_states is not None:
|
| 273 |
+
if not hasattr(self, "q_attn"):
|
| 274 |
+
raise ValueError(
|
| 275 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
| 276 |
+
"Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
query_states = self.q_attn(hidden_states)
|
| 280 |
+
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
| 281 |
+
attention_mask = encoder_attention_mask
|
| 282 |
+
else:
|
| 283 |
+
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
| 284 |
+
|
| 285 |
+
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
|
| 286 |
+
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
|
| 287 |
+
|
| 288 |
+
query_states = query_states.view(shape_q).transpose(1, 2)
|
| 289 |
+
key_states = key_states.view(shape_kv).transpose(1, 2)
|
| 290 |
+
value_states = value_states.view(shape_kv).transpose(1, 2)
|
| 291 |
+
|
| 292 |
+
if layer_past is not None:
|
| 293 |
+
past_key, past_value = layer_past
|
| 294 |
+
key_states = torch.cat((past_key, key_states), dim=-2)
|
| 295 |
+
value_states = torch.cat((past_value, value_states), dim=-2)
|
| 296 |
+
|
| 297 |
+
if use_cache is True:
|
| 298 |
+
present = (key_states, value_states)
|
| 299 |
+
else:
|
| 300 |
+
present = None
|
| 301 |
+
|
| 302 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 303 |
+
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
| 304 |
+
|
| 305 |
+
using_eager = self.config._attn_implementation == "eager"
|
| 306 |
+
attention_interface: Callable = eager_attention_forward
|
| 307 |
+
if self.config._attn_implementation != "eager":
|
| 308 |
+
if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
|
| 309 |
+
using_eager = True
|
| 310 |
+
logger.warning_once(
|
| 311 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 312 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 313 |
+
)
|
| 314 |
+
else:
|
| 315 |
+
# Attention functions are consistent with previous equivalent attention classes, however they do not support some options
|
| 316 |
+
# (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
|
| 317 |
+
# not necessarily to eager (if mentioned options are provided).
|
| 318 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 319 |
+
|
| 320 |
+
if using_eager and self.reorder_and_upcast_attn:
|
| 321 |
+
attn_output, attn_weights = self._upcast_and_reordered_attn(
|
| 322 |
+
query_states, key_states, value_states, attention_mask, head_mask
|
| 323 |
+
)
|
| 324 |
+
else:
|
| 325 |
+
attn_output, attn_weights = attention_interface(
|
| 326 |
+
self,
|
| 327 |
+
query_states,
|
| 328 |
+
key_states,
|
| 329 |
+
value_states,
|
| 330 |
+
attention_mask,
|
| 331 |
+
head_mask=head_mask,
|
| 332 |
+
dropout=self.attn_dropout.p if self.training else 0.0,
|
| 333 |
+
is_causal=is_causal,
|
| 334 |
+
**kwargs,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
|
| 338 |
+
attn_output = self.c_proj(attn_output)
|
| 339 |
+
attn_output = self.resid_dropout(attn_output)
|
| 340 |
+
|
| 341 |
+
outputs = (attn_output, present)
|
| 342 |
+
if output_attentions:
|
| 343 |
+
outputs += (attn_weights,)
|
| 344 |
+
|
| 345 |
+
return outputs # a, present, (attentions)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
|
| 349 |
+
class DecisionTransformerGPT2MLP(nn.Module):
|
| 350 |
+
def __init__(self, intermediate_size, config):
|
| 351 |
+
super().__init__()
|
| 352 |
+
embed_dim = config.hidden_size
|
| 353 |
+
self.c_fc = Conv1D(intermediate_size, embed_dim)
|
| 354 |
+
self.c_proj = Conv1D(embed_dim, intermediate_size)
|
| 355 |
+
self.act = ACT2FN[config.activation_function]
|
| 356 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 357 |
+
|
| 358 |
+
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
|
| 359 |
+
hidden_states = self.c_fc(hidden_states)
|
| 360 |
+
hidden_states = self.act(hidden_states)
|
| 361 |
+
hidden_states = self.c_proj(hidden_states)
|
| 362 |
+
hidden_states = self.dropout(hidden_states)
|
| 363 |
+
return hidden_states
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
|
| 367 |
+
class DecisionTransformerGPT2Block(nn.Module):
|
| 368 |
+
# Ignore copy
|
| 369 |
+
def __init__(self, config, layer_idx=None):
|
| 370 |
+
super().__init__()
|
| 371 |
+
hidden_size = config.hidden_size
|
| 372 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
| 373 |
+
|
| 374 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 375 |
+
self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
|
| 376 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 377 |
+
|
| 378 |
+
if config.add_cross_attention:
|
| 379 |
+
self.crossattention = DecisionTransformerGPT2Attention(
|
| 380 |
+
config, is_cross_attention=True, layer_idx=layer_idx
|
| 381 |
+
)
|
| 382 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 383 |
+
|
| 384 |
+
self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
|
| 385 |
+
|
| 386 |
+
def forward(
|
| 387 |
+
self,
|
| 388 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
| 389 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
| 390 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 391 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 392 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 393 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 394 |
+
use_cache: Optional[bool] = False,
|
| 395 |
+
output_attentions: Optional[bool] = False,
|
| 396 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
| 397 |
+
residual = hidden_states
|
| 398 |
+
hidden_states = self.ln_1(hidden_states)
|
| 399 |
+
attn_outputs = self.attn(
|
| 400 |
+
hidden_states,
|
| 401 |
+
layer_past=layer_past,
|
| 402 |
+
attention_mask=attention_mask,
|
| 403 |
+
head_mask=head_mask,
|
| 404 |
+
use_cache=use_cache,
|
| 405 |
+
output_attentions=output_attentions,
|
| 406 |
+
)
|
| 407 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
| 408 |
+
outputs = attn_outputs[1:]
|
| 409 |
+
# residual connection
|
| 410 |
+
hidden_states = attn_output + residual
|
| 411 |
+
|
| 412 |
+
if encoder_hidden_states is not None:
|
| 413 |
+
# add one self-attention block for cross-attention
|
| 414 |
+
if not hasattr(self, "crossattention"):
|
| 415 |
+
raise ValueError(
|
| 416 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
| 417 |
+
"cross-attention layers by setting `config.add_cross_attention=True`"
|
| 418 |
+
)
|
| 419 |
+
residual = hidden_states
|
| 420 |
+
hidden_states = self.ln_cross_attn(hidden_states)
|
| 421 |
+
cross_attn_outputs = self.crossattention(
|
| 422 |
+
hidden_states,
|
| 423 |
+
attention_mask=attention_mask,
|
| 424 |
+
head_mask=head_mask,
|
| 425 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 426 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 427 |
+
output_attentions=output_attentions,
|
| 428 |
+
)
|
| 429 |
+
attn_output = cross_attn_outputs[0]
|
| 430 |
+
# residual connection
|
| 431 |
+
hidden_states = residual + attn_output
|
| 432 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
| 433 |
+
|
| 434 |
+
residual = hidden_states
|
| 435 |
+
hidden_states = self.ln_2(hidden_states)
|
| 436 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
| 437 |
+
# residual connection
|
| 438 |
+
hidden_states = residual + feed_forward_hidden_states
|
| 439 |
+
|
| 440 |
+
if use_cache:
|
| 441 |
+
outputs = (hidden_states,) + outputs
|
| 442 |
+
else:
|
| 443 |
+
outputs = (hidden_states,) + outputs[1:]
|
| 444 |
+
|
| 445 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
| 449 |
+
"""
|
| 450 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 451 |
+
models.
|
| 452 |
+
"""
|
| 453 |
+
|
| 454 |
+
config_class = DecisionTransformerConfig
|
| 455 |
+
load_tf_weights = load_tf_weights_in_gpt2
|
| 456 |
+
base_model_prefix = "transformer"
|
| 457 |
+
is_parallelizable = True
|
| 458 |
+
supports_gradient_checkpointing = True
|
| 459 |
+
|
| 460 |
+
def __init__(self, *inputs, **kwargs):
|
| 461 |
+
super().__init__(*inputs, **kwargs)
|
| 462 |
+
|
| 463 |
+
def _init_weights(self, module):
|
| 464 |
+
"""Initialize the weights."""
|
| 465 |
+
if isinstance(module, (nn.Linear, Conv1D)):
|
| 466 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 467 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 468 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 469 |
+
if module.bias is not None:
|
| 470 |
+
module.bias.data.zero_()
|
| 471 |
+
elif isinstance(module, nn.Embedding):
|
| 472 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 473 |
+
if module.padding_idx is not None:
|
| 474 |
+
module.weight.data[module.padding_idx].zero_()
|
| 475 |
+
elif isinstance(module, nn.LayerNorm):
|
| 476 |
+
module.bias.data.zero_()
|
| 477 |
+
module.weight.data.fill_(1.0)
|
| 478 |
+
|
| 479 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
| 480 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
| 481 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
| 482 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
| 483 |
+
#
|
| 484 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
| 485 |
+
for name, p in module.named_parameters():
|
| 486 |
+
if "c_proj" in name and "weight" in name:
|
| 487 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
| 488 |
+
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
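As a concrete illustration of the scaled-initialization rule implemented just above (values assumed for illustration, not read from any checkpoint): with `initializer_range = 0.02` and `n_layer = 3`, every `c_proj.weight` is drawn from a normal distribution with std `0.02 / sqrt(2 * 3) ≈ 0.0082`, while all other weights keep the unscaled std of `0.02`.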
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
| 492 |
+
def __init__(self, config):
|
| 493 |
+
super().__init__(config)
|
| 494 |
+
|
| 495 |
+
self.embed_dim = config.hidden_size
|
| 496 |
+
|
| 497 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
| 498 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
| 499 |
+
|
| 500 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 501 |
+
self.h = nn.ModuleList(
|
| 502 |
+
[DecisionTransformerGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
| 503 |
+
)
|
| 504 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
| 505 |
+
|
| 506 |
+
# Model parallel
|
| 507 |
+
self.model_parallel = False
|
| 508 |
+
self.device_map = None
|
| 509 |
+
self.gradient_checkpointing = False
|
| 510 |
+
|
| 511 |
+
# Initialize weights and apply final processing
|
| 512 |
+
self.post_init()
|
| 513 |
+
|
| 514 |
+
def get_input_embeddings(self):
|
| 515 |
+
return self.wte
|
| 516 |
+
|
| 517 |
+
def set_input_embeddings(self, new_embeddings):
|
| 518 |
+
self.wte = new_embeddings
|
| 519 |
+
|
| 520 |
+
def forward(
|
| 521 |
+
self,
|
| 522 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 523 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 524 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 525 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 526 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 527 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 528 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 529 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 530 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 531 |
+
use_cache: Optional[bool] = None,
|
| 532 |
+
output_attentions: Optional[bool] = None,
|
| 533 |
+
output_hidden_states: Optional[bool] = None,
|
| 534 |
+
return_dict: Optional[bool] = None,
|
| 535 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
| 536 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 537 |
+
output_hidden_states = (
|
| 538 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 539 |
+
)
|
| 540 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 541 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 542 |
+
|
| 543 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 544 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 545 |
+
elif input_ids is not None:
|
| 546 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 547 |
+
input_shape = input_ids.size()
|
| 548 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 549 |
+
batch_size = input_ids.shape[0]
|
| 550 |
+
elif inputs_embeds is not None:
|
| 551 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 552 |
+
batch_size = inputs_embeds.shape[0]
|
| 553 |
+
else:
|
| 554 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 555 |
+
|
| 556 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 557 |
+
|
| 558 |
+
if token_type_ids is not None:
|
| 559 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
| 560 |
+
|
| 561 |
+
if past_key_values is None:
|
| 562 |
+
past_length = 0
|
| 563 |
+
past_key_values = tuple([None] * len(self.h))
|
| 564 |
+
else:
|
| 565 |
+
past_length = past_key_values[0][0].size(-2)
|
| 566 |
+
if position_ids is None:
|
| 567 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
| 568 |
+
position_ids = position_ids.unsqueeze(0)
|
| 569 |
+
|
| 570 |
+
# Attention mask.
|
| 571 |
+
if attention_mask is not None:
|
| 572 |
+
if batch_size <= 0:
|
| 573 |
+
raise ValueError("batch_size has to be defined and > 0")
|
| 574 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
| 575 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 576 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 577 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 578 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 579 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 580 |
+
attention_mask = attention_mask[:, None, None, :]
|
| 581 |
+
|
| 582 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 583 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 584 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
| 585 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 586 |
+
# effectively the same as removing these entirely.
|
| 587 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 588 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
| 589 |
+
|
| 590 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 591 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 592 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
| 593 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 594 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 595 |
+
if encoder_attention_mask is None:
|
| 596 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 597 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 598 |
+
else:
|
| 599 |
+
encoder_attention_mask = None
|
| 600 |
+
|
| 601 |
+
# Prepare head mask if needed
|
| 602 |
+
# 1.0 in head_mask indicate we keep the head
|
| 603 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 604 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
| 605 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 606 |
+
|
| 607 |
+
if inputs_embeds is None:
|
| 608 |
+
inputs_embeds = self.wte(input_ids)
|
| 609 |
+
position_embeds = self.wpe(position_ids)
|
| 610 |
+
hidden_states = inputs_embeds + position_embeds
|
| 611 |
+
|
| 612 |
+
if token_type_ids is not None:
|
| 613 |
+
token_type_embeds = self.wte(token_type_ids)
|
| 614 |
+
hidden_states = hidden_states + token_type_embeds
|
| 615 |
+
|
| 616 |
+
hidden_states = self.drop(hidden_states)
|
| 617 |
+
|
| 618 |
+
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
| 619 |
+
|
| 620 |
+
if self.gradient_checkpointing and self.training:
|
| 621 |
+
if use_cache:
|
| 622 |
+
logger.warning_once(
|
| 623 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 624 |
+
)
|
| 625 |
+
use_cache = False
|
| 626 |
+
|
| 627 |
+
presents = () if use_cache else None
|
| 628 |
+
all_self_attentions = () if output_attentions else None
|
| 629 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 630 |
+
all_hidden_states = () if output_hidden_states else None
|
| 631 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
| 632 |
+
# Model parallel
|
| 633 |
+
if self.model_parallel:
|
| 634 |
+
torch.cuda.set_device(hidden_states.device)
|
| 635 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
| 636 |
+
if layer_past is not None:
|
| 637 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
| 638 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
| 639 |
+
if attention_mask is not None:
|
| 640 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
| 641 |
+
if isinstance(head_mask, torch.Tensor):
|
| 642 |
+
head_mask = head_mask.to(hidden_states.device)
|
| 643 |
+
if output_hidden_states:
|
| 644 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 645 |
+
|
| 646 |
+
if self.gradient_checkpointing and self.training:
|
| 647 |
+
outputs = self._gradient_checkpointing_func(
|
| 648 |
+
block.__call__,
|
| 649 |
+
hidden_states,
|
| 650 |
+
None,
|
| 651 |
+
attention_mask,
|
| 652 |
+
head_mask[i],
|
| 653 |
+
encoder_hidden_states,
|
| 654 |
+
encoder_attention_mask,
|
| 655 |
+
use_cache,
|
| 656 |
+
output_attentions,
|
| 657 |
+
)
|
| 658 |
+
else:
|
| 659 |
+
outputs = block(
|
| 660 |
+
hidden_states,
|
| 661 |
+
layer_past=layer_past,
|
| 662 |
+
attention_mask=attention_mask,
|
| 663 |
+
head_mask=head_mask[i],
|
| 664 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 665 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 666 |
+
use_cache=use_cache,
|
| 667 |
+
output_attentions=output_attentions,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
hidden_states = outputs[0]
|
| 671 |
+
if use_cache is True:
|
| 672 |
+
presents = presents + (outputs[1],)
|
| 673 |
+
|
| 674 |
+
if output_attentions:
|
| 675 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
| 676 |
+
if self.config.add_cross_attention:
|
| 677 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
| 678 |
+
|
| 679 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
| 680 |
+
if self.model_parallel:
|
| 681 |
+
for k, v in self.device_map.items():
|
| 682 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
| 683 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
| 684 |
+
|
| 685 |
+
hidden_states = self.ln_f(hidden_states)
|
| 686 |
+
|
| 687 |
+
hidden_states = hidden_states.view(output_shape)
|
| 688 |
+
# Add last hidden state
|
| 689 |
+
if output_hidden_states:
|
| 690 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 691 |
+
|
| 692 |
+
if not return_dict:
|
| 693 |
+
return tuple(
|
| 694 |
+
v
|
| 695 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
| 696 |
+
if v is not None
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 700 |
+
last_hidden_state=hidden_states,
|
| 701 |
+
past_key_values=presents,
|
| 702 |
+
hidden_states=all_hidden_states,
|
| 703 |
+
attentions=all_self_attentions,
|
| 704 |
+
cross_attentions=all_cross_attentions,
|
| 705 |
+
)
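A standalone sketch (toy values, independent of the class above) of the additive mask built in the forward pass: keep/mask flags of 1/0 become biases of 0/`finfo.min` that are added to the raw attention scores, so masked positions receive essentially zero probability after the softmax.

```python
import torch

attention_mask = torch.tensor([[1, 1, 0]])                  # keep, keep, mask
bias = attention_mask[:, None, None, :].to(torch.float32)   # [batch, 1, 1, to_seq_len]
bias = (1.0 - bias) * torch.finfo(torch.float32).min

scores = torch.zeros(1, 1, 1, 3) + bias                     # pretend all raw scores are 0
probs = torch.softmax(scores, dim=-1)
print(probs)  # ~[[0.5, 0.5, 0.0]] -- the masked position gets (almost) no weight
```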
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
@dataclass
|
| 709 |
+
class DecisionTransformerOutput(ModelOutput):
|
| 710 |
+
"""
|
| 711 |
+
Base class for the outputs of the Decision Transformer model.
|
| 712 |
+
|
| 713 |
+
Args:
|
| 714 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 715 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 716 |
+
state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
|
| 717 |
+
Environment state predictions
|
| 718 |
+
action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
|
| 719 |
+
Model action predictions
|
| 720 |
+
return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
|
| 721 |
+
Predicted returns for each state
|
| 722 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 723 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 724 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 725 |
+
|
| 726 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 727 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 728 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 729 |
+
sequence_length)`.
|
| 730 |
+
|
| 731 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 732 |
+
heads.
|
| 733 |
+
"""
|
| 734 |
+
|
| 735 |
+
state_preds: torch.FloatTensor = None
|
| 736 |
+
action_preds: torch.FloatTensor = None
|
| 737 |
+
return_preds: torch.FloatTensor = None
|
| 738 |
+
hidden_states: torch.FloatTensor = None
|
| 739 |
+
attentions: torch.FloatTensor = None
|
| 740 |
+
last_hidden_state: torch.FloatTensor = None
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class DecisionTransformerPreTrainedModel(PreTrainedModel):
|
| 744 |
+
"""
|
| 745 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 746 |
+
models.
|
| 747 |
+
"""
|
| 748 |
+
|
| 749 |
+
config_class = DecisionTransformerConfig
|
| 750 |
+
base_model_prefix = "decision_transformer"
|
| 751 |
+
main_input_name = "states"
|
| 752 |
+
supports_gradient_checkpointing = False
|
| 753 |
+
|
| 754 |
+
def _init_weights(self, module):
|
| 755 |
+
"""Initialize the weights"""
|
| 756 |
+
if isinstance(module, nn.Linear):
|
| 757 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 758 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 759 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 760 |
+
if module.bias is not None:
|
| 761 |
+
module.bias.data.zero_()
|
| 762 |
+
elif isinstance(module, nn.Embedding):
|
| 763 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 764 |
+
if module.padding_idx is not None:
|
| 765 |
+
module.weight.data[module.padding_idx].zero_()
|
| 766 |
+
elif isinstance(module, nn.LayerNorm):
|
| 767 |
+
module.bias.data.zero_()
|
| 768 |
+
module.weight.data.fill_(1.0)
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
DECISION_TRANSFORMER_START_DOCSTRING = r"""
|
| 772 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
| 773 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 774 |
+
behavior.
|
| 775 |
+
|
| 776 |
+
Parameters:
|
| 777 |
+
config ([`~DecisionTransformerConfig`]): Model configuration class with all the parameters of the model.
|
| 778 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 779 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 780 |
+
"""
|
| 781 |
+
|
| 782 |
+
DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
|
| 783 |
+
Args:
|
| 784 |
+
states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
|
| 785 |
+
The states for each step in the trajectory
|
| 786 |
+
actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
|
| 787 |
+
The actions taken by the "expert" policy for the current state; these are masked for autoregressive
|
| 788 |
+
prediction
|
| 789 |
+
rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
|
| 790 |
+
The rewards for each state, action
|
| 791 |
+
returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
|
| 792 |
+
The returns for each state in the trajectory
|
| 793 |
+
timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
|
| 794 |
+
The timestep for each step in the trajectory
|
| 795 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
|
| 796 |
+
Masking, used to mask the actions when performing autoregressive prediction
|
| 797 |
+
"""
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
@add_start_docstrings("The Decision Transformer Model", DECISION_TRANSFORMER_START_DOCSTRING)
|
| 801 |
+
class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
|
| 802 |
+
"""
|
| 803 |
+
|
| 804 |
+
The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL
|
| 805 |
+
setting. Refer to the paper for more details: https://arxiv.org/abs/2106.01345
|
| 806 |
+
|
| 807 |
+
"""
|
| 808 |
+
|
| 809 |
+
def __init__(self, config):
|
| 810 |
+
super().__init__(config)
|
| 811 |
+
self.config = config
|
| 812 |
+
self.hidden_size = config.hidden_size
|
| 813 |
+
# note: the only difference between this GPT2Model and the default Huggingface version
|
| 814 |
+
# is that the positional embeddings are removed (since we'll add those ourselves)
|
| 815 |
+
self.encoder = DecisionTransformerGPT2Model(config)
|
| 816 |
+
|
| 817 |
+
self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)
|
| 818 |
+
self.embed_return = torch.nn.Linear(1, config.hidden_size)
|
| 819 |
+
self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
|
| 820 |
+
self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)
|
| 821 |
+
|
| 822 |
+
self.embed_ln = nn.LayerNorm(config.hidden_size)
|
| 823 |
+
|
| 824 |
+
# note: we don't predict states or returns for the paper
|
| 825 |
+
self.predict_state = torch.nn.Linear(config.hidden_size, config.state_dim)
|
| 826 |
+
self.predict_action = nn.Sequential(
|
| 827 |
+
*([nn.Linear(config.hidden_size, config.act_dim)] + ([nn.Tanh()] if config.action_tanh else []))
|
| 828 |
+
)
|
| 829 |
+
self.predict_return = torch.nn.Linear(config.hidden_size, 1)
|
| 830 |
+
|
| 831 |
+
# Initialize weights and apply final processing
|
| 832 |
+
self.post_init()
|
| 833 |
+
|
| 834 |
+
@add_start_docstrings_to_model_forward(DECISION_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 835 |
+
@replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
|
| 836 |
+
def forward(
|
| 837 |
+
self,
|
| 838 |
+
states: Optional[torch.FloatTensor] = None,
|
| 839 |
+
actions: Optional[torch.FloatTensor] = None,
|
| 840 |
+
rewards: Optional[torch.FloatTensor] = None,
|
| 841 |
+
returns_to_go: Optional[torch.FloatTensor] = None,
|
| 842 |
+
timesteps: Optional[torch.LongTensor] = None,
|
| 843 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 844 |
+
output_hidden_states: Optional[bool] = None,
|
| 845 |
+
output_attentions: Optional[bool] = None,
|
| 846 |
+
return_dict: Optional[bool] = None,
|
| 847 |
+
) -> Union[Tuple[torch.FloatTensor], DecisionTransformerOutput]:
|
| 848 |
+
r"""
|
| 849 |
+
Returns:
|
| 850 |
+
|
| 851 |
+
Examples:
|
| 852 |
+
|
| 853 |
+
```python
|
| 854 |
+
>>> from transformers import DecisionTransformerModel
|
| 855 |
+
>>> import torch
|
| 856 |
+
|
| 857 |
+
>>> model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
|
| 858 |
+
>>> # evaluation
|
| 859 |
+
>>> model = model.to(device)
|
| 860 |
+
>>> model.eval()
|
| 861 |
+
|
| 862 |
+
>>> env = gym.make("Hopper-v3")
|
| 863 |
+
>>> state_dim = env.observation_space.shape[0]
|
| 864 |
+
>>> act_dim = env.action_space.shape[0]
|
| 865 |
+
|
| 866 |
+
>>> state = env.reset()
|
| 867 |
+
>>> states = torch.from_numpy(state).reshape(1, 1, state_dim).to(device=device, dtype=torch.float32)
|
| 868 |
+
>>> actions = torch.zeros((1, 1, act_dim), device=device, dtype=torch.float32)
|
| 869 |
+
>>> rewards = torch.zeros(1, 1, device=device, dtype=torch.float32)
|
| 870 |
+
>>> target_return = torch.tensor(TARGET_RETURN, dtype=torch.float32).reshape(1, 1)
|
| 871 |
+
>>> timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
|
| 872 |
+
>>> attention_mask = torch.zeros(1, 1, device=device, dtype=torch.float32)
|
| 873 |
+
|
| 874 |
+
>>> # forward pass
|
| 875 |
+
>>> with torch.no_grad():
|
| 876 |
+
... state_preds, action_preds, return_preds = model(
|
| 877 |
+
... states=states,
|
| 878 |
+
... actions=actions,
|
| 879 |
+
... rewards=rewards,
|
| 880 |
+
... returns_to_go=target_return,
|
| 881 |
+
... timesteps=timesteps,
|
| 882 |
+
... attention_mask=attention_mask,
|
| 883 |
+
... return_dict=False,
|
| 884 |
+
... )
|
| 885 |
+
```"""
|
| 886 |
+
|
| 887 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 888 |
+
output_hidden_states = (
|
| 889 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 890 |
+
)
|
| 891 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 892 |
+
|
| 893 |
+
batch_size, seq_length = states.shape[0], states.shape[1]
|
| 894 |
+
|
| 895 |
+
if attention_mask is None:
|
| 896 |
+
# attention mask for GPT: 1 if can be attended to, 0 if not
|
| 897 |
+
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
|
| 898 |
+
|
| 899 |
+
# embed each modality with a different head
|
| 900 |
+
state_embeddings = self.embed_state(states)
|
| 901 |
+
action_embeddings = self.embed_action(actions)
|
| 902 |
+
returns_embeddings = self.embed_return(returns_to_go)
|
| 903 |
+
time_embeddings = self.embed_timestep(timesteps)
|
| 904 |
+
|
| 905 |
+
# time embeddings are treated similar to positional embeddings
|
| 906 |
+
state_embeddings = state_embeddings + time_embeddings
|
| 907 |
+
action_embeddings = action_embeddings + time_embeddings
|
| 908 |
+
returns_embeddings = returns_embeddings + time_embeddings
|
| 909 |
+
|
| 910 |
+
# this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
|
| 911 |
+
# which works nicely in an autoregressive sense since states predict actions
|
| 912 |
+
stacked_inputs = (
|
| 913 |
+
torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
|
| 914 |
+
.permute(0, 2, 1, 3)
|
| 915 |
+
.reshape(batch_size, 3 * seq_length, self.hidden_size)
|
| 916 |
+
)
|
| 917 |
+
stacked_inputs = self.embed_ln(stacked_inputs)
|
| 918 |
+
|
| 919 |
+
# to make the attention mask fit the stacked inputs, have to stack it as well
|
| 920 |
+
stacked_attention_mask = (
|
| 921 |
+
torch.stack((attention_mask, attention_mask, attention_mask), dim=1)
|
| 922 |
+
.permute(0, 2, 1)
|
| 923 |
+
.reshape(batch_size, 3 * seq_length)
|
| 924 |
+
)
|
| 925 |
+
device = stacked_inputs.device
|
| 926 |
+
# we feed in the input embeddings (not word indices as in NLP) to the model
|
| 927 |
+
encoder_outputs = self.encoder(
|
| 928 |
+
inputs_embeds=stacked_inputs,
|
| 929 |
+
attention_mask=stacked_attention_mask,
|
| 930 |
+
position_ids=torch.zeros(stacked_attention_mask.shape, device=device, dtype=torch.long),
|
| 931 |
+
output_attentions=output_attentions,
|
| 932 |
+
output_hidden_states=output_hidden_states,
|
| 933 |
+
return_dict=return_dict,
|
| 934 |
+
)
|
| 935 |
+
x = encoder_outputs[0]
|
| 936 |
+
|
| 937 |
+
# reshape x so that the second dimension corresponds to the original
|
| 938 |
+
# returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
|
| 939 |
+
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
|
| 940 |
+
|
| 941 |
+
# get predictions
|
| 942 |
+
return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
|
| 943 |
+
state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
|
| 944 |
+
action_preds = self.predict_action(x[:, 1]) # predict next action given state
|
| 945 |
+
if not return_dict:
|
| 946 |
+
return (state_preds, action_preds, return_preds)
|
| 947 |
+
|
| 948 |
+
return DecisionTransformerOutput(
|
| 949 |
+
last_hidden_state=encoder_outputs.last_hidden_state,
|
| 950 |
+
state_preds=state_preds,
|
| 951 |
+
action_preds=action_preds,
|
| 952 |
+
return_preds=return_preds,
|
| 953 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 954 |
+
attentions=encoder_outputs.attentions,
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
__all__ = [
|
| 959 |
+
"DecisionTransformerGPT2Model",
|
| 960 |
+
"DecisionTransformerGPT2PreTrainedModel",
|
| 961 |
+
"DecisionTransformerModel",
|
| 962 |
+
"DecisionTransformerPreTrainedModel",
|
| 963 |
+
]
|
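To make the `(R_1, s_1, a_1, R_2, s_2, a_2, ...)` interleaving used by `DecisionTransformerModel.forward` concrete, here is a standalone sketch with dummy embeddings (shapes and constant fill values are made up for illustration):

```python
import torch

batch_size, seq_length, hidden_size = 2, 3, 4
returns_embeddings = torch.full((batch_size, seq_length, hidden_size), 0.0)  # R_t
state_embeddings = torch.full((batch_size, seq_length, hidden_size), 1.0)    # s_t
action_embeddings = torch.full((batch_size, seq_length, hidden_size), 2.0)   # a_t

stacked_inputs = (
    torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
    .permute(0, 2, 1, 3)
    .reshape(batch_size, 3 * seq_length, hidden_size)
)
# Tokens along dim 1 now read R_1, s_1, a_1, R_2, s_2, a_2, R_3, s_3, a_3
print(stacked_inputs[0, :, 0])  # tensor([0., 1., 2., 0., 1., 2., 0., 1., 2.])

# Undoing the interleaving recovers per-modality views, matching x[:, 0]=returns,
# x[:, 1]=states, x[:, 2]=actions in the forward pass
x = stacked_inputs.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)
assert torch.equal(x[:, 1], state_embeddings)
```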
.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (772 Bytes).
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .tokenization_gpt_sw3 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (730 Bytes).
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc
ADDED
Binary file (16 kB).
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
ADDED
|
@@ -0,0 +1,299 @@
| 1 |
+
"""The tokenizer used by the GPT-SW3 models."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import unicodedata
|
| 6 |
+
from shutil import copyfile
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import sentencepiece as spm
|
| 10 |
+
|
| 11 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
| 12 |
+
from ...utils import is_torch_available, logging
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
if is_torch_available():
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class GPTSw3Tokenizer(PreTrainedTokenizer):
|
| 24 |
+
"""
|
| 25 |
+
Construct a GPTSw3 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
|
| 26 |
+
|
| 27 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
| 28 |
+
this superclass for more information regarding those methods.
|
| 29 |
+
|
| 30 |
+
Example usage:
|
| 31 |
+
```python
|
| 32 |
+
>>> from transformers import GPTSw3Tokenizer
|
| 33 |
+
|
| 34 |
+
>>> tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")
|
| 35 |
+
>>> tokenizer("Svenska är kul!")["input_ids"]
|
| 36 |
+
[1814, 377, 3617, 63504]
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
vocab_file (`str`):
|
| 41 |
+
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
|
| 42 |
+
contains the vocabulary necessary to instantiate a tokenizer.
|
| 43 |
+
do_lower_case (`bool`, *optional*, defaults to `False`):
|
| 44 |
+
Whether or not to lowercase the input when tokenizing.
|
| 45 |
+
remove_space (`bool`, *optional*, defaults to `False`):
|
| 46 |
+
Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
|
| 47 |
+
keep_accents (`bool`, *optional*, defaults to `False`):
|
| 48 |
+
Whether or not to keep accents when tokenizing.
|
| 49 |
+
pad_token (`str`, *optional*):
|
| 50 |
+
The token used for padding, for example when batching sequences of different lengths. If not provided, will
|
| 51 |
+
default to '<pad>' or '<unk>' depending on model size.
|
| 52 |
+
unk_token (`str`, *optional*):
|
| 53 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 54 |
+
token instead. If not provided, will default to '<unk>'.
|
| 55 |
+
eos_token (`str`, *optional*):
|
| 56 |
+
The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'
|
| 57 |
+
bos_token (`str`, *optional*):
|
| 58 |
+
The beginning of sequence token that can be used for downstream task, was not seen during pretraining. If
|
| 59 |
+
not provided, will default to '<s>' or '<|endoftext|>', depending on model size.
|
| 60 |
+
sp_model_kwargs (`dict`, *optional*):
|
| 61 |
+
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
|
| 62 |
+
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
|
| 63 |
+
to set:
|
| 64 |
+
|
| 65 |
+
- `enable_sampling`: Enable subword regularization.
|
| 66 |
+
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
|
| 67 |
+
|
| 68 |
+
- `nbest_size = {0,1}`: No sampling is performed.
|
| 69 |
+
- `nbest_size > 1`: samples from the nbest_size results.
|
| 70 |
+
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
|
| 71 |
+
using forward-filtering-and-backward-sampling algorithm.
|
| 72 |
+
|
| 73 |
+
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
|
| 74 |
+
BPE-dropout.
|
| 75 |
+
|
| 76 |
+
Attributes:
|
| 77 |
+
sp_model (`SentencePieceProcessor`):
|
| 78 |
+
The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
|
| 79 |
+
whitespaces (`set`):
|
| 80 |
+
The whitespaces that are replaced in the whitespace normalization in preprocessing.
|
| 81 |
+
non_printing_characters_re (`Pattern`):
|
| 82 |
+
The compiled regular expression to remove non-printing characters in preprocessing.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 86 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
vocab_file,
|
| 91 |
+
do_lower_case=False,
|
| 92 |
+
remove_space=False,
|
| 93 |
+
keep_accents=False,
|
| 94 |
+
pad_token=None,
|
| 95 |
+
unk_token=None,
|
| 96 |
+
eos_token=None,
|
| 97 |
+
bos_token=None,
|
| 98 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 99 |
+
**kwargs,
|
| 100 |
+
) -> None:
|
| 101 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 102 |
+
|
| 103 |
+
name_or_path = kwargs.get("name_or_path")
|
| 104 |
+
if name_or_path is None:
|
| 105 |
+
logger.warning(
|
| 106 |
+
"name_or_path not provided, will work for all GPTSw3 models except gpt-sw3-7b,"
|
| 107 |
+
" you are testing the model, this can safely be ignored"
|
| 108 |
+
)
|
| 109 |
+
name_or_path = "None"
|
| 110 |
+
|
| 111 |
+
# Default definitions for our 2 tokenizer versions, with None-checks to enable proper testing
|
| 112 |
+
eos_token = "<|endoftext|>" if eos_token is None else eos_token
|
| 113 |
+
unk_token = "<unk>" if unk_token is None else unk_token
|
| 114 |
+
if "gpt-sw3-7b" in name_or_path:
|
| 115 |
+
pad_token = unk_token if pad_token is None else pad_token
|
| 116 |
+
bos_token = eos_token if bos_token is None else bos_token
|
| 117 |
+
else:
|
| 118 |
+
pad_token = "<pad>" if pad_token is None else pad_token
|
| 119 |
+
bos_token = "<s>" if bos_token is None else bos_token
|
| 120 |
+
|
| 121 |
+
self.do_lower_case = do_lower_case
|
| 122 |
+
self.remove_space = remove_space
|
| 123 |
+
self.keep_accents = keep_accents
|
| 124 |
+
self.vocab_file = vocab_file
|
| 125 |
+
|
| 126 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 127 |
+
self.sp_model.Load(vocab_file)
|
| 128 |
+
|
| 129 |
+
# Used for whitespace normalization in input texts
|
| 130 |
+
# fmt: off
|
| 131 |
+
self.whitespaces = {" ", " ", " ", " ", " ", " ", " ", " ", " ", " ", "", ""}
|
| 132 |
+
# fmt: on
|
| 133 |
+
|
| 134 |
+
# Regular expression to remove non-printing characters (e.g. some unicode control chars) in preprocessing
|
| 135 |
+
self.non_printing_characters_re = re.compile(
|
| 136 |
+
f"[{''.join(map(chr, list(range(0, 9)) + list(range(11, 32)) + list(range(127, 160)) + [160, 173, 8203]))}]"
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
super().__init__(
|
| 140 |
+
do_lower_case=do_lower_case,
|
| 141 |
+
remove_space=remove_space,
|
| 142 |
+
keep_accents=keep_accents,
|
| 143 |
+
bos_token=bos_token,
|
| 144 |
+
eos_token=eos_token,
|
| 145 |
+
unk_token=unk_token,
|
| 146 |
+
pad_token=pad_token,
|
| 147 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
| 148 |
+
**kwargs,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__getstate__
|
| 152 |
+
def __getstate__(self):
|
| 153 |
+
state = self.__dict__.copy()
|
| 154 |
+
state["sp_model"] = None
|
| 155 |
+
return state
|
| 156 |
+
|
| 157 |
+
# Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__setstate__
|
| 158 |
+
def __setstate__(self, d):
|
| 159 |
+
self.__dict__ = d
|
| 160 |
+
|
| 161 |
+
# for backward compatibility
|
| 162 |
+
if not hasattr(self, "sp_model_kwargs"):
|
| 163 |
+
self.sp_model_kwargs = {}
|
| 164 |
+
|
| 165 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 166 |
+
self.sp_model.Load(self.vocab_file)
|
| 167 |
+
|
| 168 |
+
@property
|
| 169 |
+
# Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.vocab_size
|
| 170 |
+
def vocab_size(self) -> int:
|
| 171 |
+
return len(self.sp_model)
|
| 172 |
+
|
| 173 |
+
def preprocess_text(self, text: str) -> str:
|
| 174 |
+
"""
|
| 175 |
+
Returns the preprocessed text. This procedure is identical to what was used when training the tokenizer.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
# Remove non-printing characters
|
| 179 |
+
text = self.non_printing_characters_re.sub("", text)
|
| 180 |
+
|
| 181 |
+
# Normalize whitespaces
|
| 182 |
+
text = "".join([char if char not in self.whitespaces else " " for char in text])
|
| 183 |
+
|
| 184 |
+
# NFC Unicode normalization
|
| 185 |
+
text = unicodedata.normalize("NFC", text)
|
| 186 |
+
return text
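For illustration of the NFC step in `preprocess_text` above (example string chosen here, not taken from the file): NFC normalization collapses decomposed characters into single code points.

```python
import unicodedata

decomposed = "a\u030a"                           # "a" + combining ring above (two code points)
composed = unicodedata.normalize("NFC", decomposed)
print(composed, len(decomposed), len(composed))  # å 2 1
```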
|
| 187 |
+
|
| 188 |
+
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
| 189 |
+
text = self.preprocess_text(text)
|
| 190 |
+
return self.sp_model.encode(text, out_type=str)
|
| 191 |
+
|
| 192 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 193 |
+
"""Converts a token (str) to an id (int) using the vocab."""
|
| 194 |
+
return self.sp_model.PieceToId(token)
|
| 195 |
+
|
| 196 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 197 |
+
"""Converts an index (int) to a token (str) using the vocab."""
|
| 198 |
+
return self.sp_model.IdToPiece(index)
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
def clean_up_tokenization(out_string: str) -> str:
|
| 202 |
+
"""Returns the input string, this function is overridden to remove the default clean up."""
|
| 203 |
+
return out_string
|
| 204 |
+
|
| 205 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 206 |
+
"""Converts a sequence of tokens (strings) to a single string. Special tokens remain intact."""
|
| 207 |
+
current_sub_tokens = []
|
| 208 |
+
out_string = ""
|
| 209 |
+
prev_is_special = False
|
| 210 |
+
for token in tokens:
|
| 211 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 212 |
+
if token in self.all_special_tokens:
|
| 213 |
+
# TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
|
| 214 |
+
if not prev_is_special:
|
| 215 |
+
out_string += " "
|
| 216 |
+
|
| 217 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
| 218 |
+
prev_is_special = True
|
| 219 |
+
current_sub_tokens = []
|
| 220 |
+
else:
|
| 221 |
+
current_sub_tokens.append(token)
|
| 222 |
+
prev_is_special = False
|
| 223 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 224 |
+
|
| 225 |
+
return out_string
|
| 226 |
+
|
| 227 |
+
# Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.get_vocab
|
| 228 |
+
def get_vocab(self) -> Dict[str, int]:
|
| 229 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 230 |
+
vocab.update(self.added_tokens_encoder)
|
| 231 |
+
return vocab
|
| 232 |
+
|
| 233 |
+
# Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.save_vocabulary
|
| 234 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 235 |
+
if not os.path.isdir(save_directory):
|
| 236 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 237 |
+
return
|
| 238 |
+
out_vocab_file = os.path.join(
|
| 239 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 243 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 244 |
+
elif not os.path.isfile(self.vocab_file):
|
| 245 |
+
with open(out_vocab_file, "wb") as fi:
|
| 246 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 247 |
+
fi.write(content_spiece_model)
|
| 248 |
+
|
| 249 |
+
return (out_vocab_file,)
|
| 250 |
+
|
| 251 |
+
def encode_fast(
|
| 252 |
+
self, text: Union[str, List[str]], return_tensors: Union[str, bool] = False
|
| 253 |
+
) -> Union[List[int], List[List[int]], "torch.Tensor"]:
|
| 254 |
+
"""
|
| 255 |
+
Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
|
| 256 |
+
functionality but is often much faster.
|
| 257 |
+
|
| 258 |
+
Does NOT handle special tokens correctly; these can manually be added as ids afterwards.
|
| 259 |
+
|
| 260 |
+
Does NOT support padding; padding ids can manually be added afterwards.
|
| 261 |
+
|
| 262 |
+
Use default HuggingFace tokenization methods for full functionality.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
text (`str` or `List[str]`): One or several text(s) to convert to token ids.
|
| 266 |
+
return_tensors (`str` or `bool`): Returns PyTorch tensors if set to True or "pt"
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
`List[int]`, `List[List[int]]`, or `torch.Tensor`: The encoded text(s) as token ids.
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
if isinstance(text, str):
|
| 273 |
+
text = self.preprocess_text(text)
|
| 274 |
+
token_ids = self.sp_model.encode(text)
|
| 275 |
+
else:
|
| 276 |
+
text = [self.preprocess_text(t) for t in text]
|
| 277 |
+
token_ids = self.sp_model.encode(text)
|
| 278 |
+
|
| 279 |
+
if return_tensors is True or return_tensors == "pt":
|
| 280 |
+
token_ids = torch.tensor(token_ids)
|
| 281 |
+
|
| 282 |
+
return token_ids
|
| 283 |
+
|
| 284 |
+
def decode_fast(self, token_ids: Union[int, List[int]]) -> str:
|
| 285 |
+
"""
|
| 286 |
+
Decodes a token id or a sequence of token ids back to a string using the raw SP tokenizer. This has reduced
|
| 287 |
+
functionality but is often much faster.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
token_ids (`int` or `List[int]`): Encoded token or text as token id(s).
|
| 291 |
+
|
| 292 |
+
Returns:
|
| 293 |
+
`str`: Decoded text
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
return self.sp_model.decode(token_ids)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
__all__ = ["GPTSw3Tokenizer"]
|
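A short usage sketch of the fast tokenization paths defined above (the checkpoint name matches the docstring example; loading it needs network access and the `sentencepiece` dependency):

```python
from transformers import GPTSw3Tokenizer

tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")

# encode_fast skips special-token handling and padding but is much cheaper
ids = tokenizer.encode_fast("Svenska är kul!")
print(ids)

# decode_fast goes straight through the SentencePiece model
print(tokenizer.decode_fast(ids))
```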
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_musicgen import *
|
| 22 |
+
from .modeling_musicgen import *
|
| 23 |
+
from .processing_musicgen import *
|
| 24 |
+
else:
|
| 25 |
+
import sys
|
| 26 |
+
|
| 27 |
+
_file = globals()["__file__"]
|
| 28 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
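For context, a short sketch of what the `_LazyModule` registration above buys at import time; the attribute names are taken from the musicgen submodules in this diff, and the exact set of exported names is an assumption since their `__all__` lists are not all shown here:

```python
# Hedged sketch: importing the package is cheap; heavy submodules load only on first attribute access.
from transformers.models import musicgen

config_cls = musicgen.MusicgenConfig        # resolves configuration_musicgen lazily
processor_cls = musicgen.MusicgenProcessor  # resolves processing_musicgen lazily
```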
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (813 Bytes).
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc
ADDED
Binary file (11.1 kB).
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc
ADDED
Binary file (6.39 kB).
.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py
ADDED
@@ -0,0 +1,247 @@
# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MusicGen model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig


logger = logging.get_logger(__name__)


class MusicgenDecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MusicgenDecoder`]. It is used to instantiate a
    MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the MusicGen
    [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 2048):
            Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`MusicgenDecoder`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer block.
        ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically, set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_factor (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions (not used by all models).
        num_codebooks (`int`, *optional*, defaults to 4):
            The number of parallel codebooks forwarded to the model.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether input and output word embeddings should be tied.
        audio_channels (`int`, *optional*, defaults to 1):
            Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
            audio stream for the left/right output channels. Mono models generate a single audio stream output.
    """

    model_type = "musicgen_decoder"
    base_config_key = "decoder_config"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=2048,
        max_position_embeddings=2048,
        num_hidden_layers=24,
        ffn_dim=4096,
        num_attention_heads=16,
        layerdrop=0.0,
        use_cache=True,
        activation_function="gelu",
        hidden_size=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        initializer_factor=0.02,
        scale_embedding=False,
        num_codebooks=4,
        audio_channels=1,
        pad_token_id=2048,
        bos_token_id=2048,
        eos_token_id=None,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.ffn_dim = ffn_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.initializer_factor = initializer_factor
        self.layerdrop = layerdrop
        self.use_cache = use_cache
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.num_codebooks = num_codebooks

        if audio_channels not in [1, 2]:
            raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
        self.audio_channels = audio_channels

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class MusicgenConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
    MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
    configs.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        kwargs (*optional*):
            Dictionary of keyword arguments. Notably:

                - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the text encoder config.
                - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
                  defines the audio encoder config.
                - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
                  the decoder config.

    Example:

    ```python
    >>> from transformers import (
    ...     MusicgenConfig,
    ...     MusicgenDecoderConfig,
    ...     T5Config,
    ...     EncodecConfig,
    ...     MusicgenForConditionalGeneration,
    ... )

    >>> # Initializing text encoder, audio encoder, and decoder model configurations
    >>> text_encoder_config = T5Config()
    >>> audio_encoder_config = EncodecConfig()
    >>> decoder_config = MusicgenDecoderConfig()

    >>> configuration = MusicgenConfig.from_sub_models_config(
    ...     text_encoder_config, audio_encoder_config, decoder_config
    ... )

    >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
    >>> model = MusicgenForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    >>> config_text_encoder = model.config.text_encoder
    >>> config_audio_encoder = model.config.audio_encoder
    >>> config_decoder = model.config.decoder

    >>> # Saving the model, including its configuration
    >>> model.save_pretrained("musicgen-model")

    >>> # loading model and config from pretrained folder
    >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
    >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
    ```"""

    model_type = "musicgen"
    sub_configs = {
        "text_encoder": AutoConfig,
        "audio_encoder": AutoConfig,
        "decoder": MusicgenDecoderConfig,
    }
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")

        text_encoder_config = kwargs.pop("text_encoder")
        text_encoder_model_type = text_encoder_config.pop("model_type")

        audio_encoder_config = kwargs.pop("audio_encoder")
        audio_encoder_model_type = audio_encoder_config.pop("model_type")

        decoder_config = kwargs.pop("decoder")

        self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
        self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
        self.decoder = MusicgenDecoderConfig(**decoder_config)
        self.is_encoder_decoder = True

    @classmethod
    def from_sub_models_config(
        cls,
        text_encoder_config: PretrainedConfig,
        audio_encoder_config: PretrainedConfig,
        decoder_config: MusicgenDecoderConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
        configurations.

        Returns:
            [`MusicgenConfig`]: An instance of a configuration object
        """

        return cls(
            text_encoder=text_encoder_config.to_dict(),
            audio_encoder=audio_encoder_config.to_dict(),
            decoder=decoder_config.to_dict(),
            **kwargs,
        )

    @property
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
        return self.audio_encoder.sampling_rate


__all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]
.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py
ADDED
The diff for this file is too large to render. See raw diff.
.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py
ADDED
@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Text/audio processor class for MusicGen
"""

from typing import List, Optional

import numpy as np

from ...processing_utils import ProcessorMixin
from ...utils import to_numpy


class MusicgenProcessor(ProcessorMixin):
    r"""
    Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
    class.

    [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`T5Tokenizer`]. See
    [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.

    Args:
        feature_extractor (`EncodecFeatureExtractor`):
            An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`T5Tokenizer`):
            An instance of [`T5Tokenizer`]. The tokenizer is a required input.
    """

    feature_extractor_class = "EncodecFeatureExtractor"
    tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
        return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)

    def __call__(self, *args, **kwargs):
        """
        Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
        argument to [`~T5Tokenizer.__call__`]. Please refer to the docstring of the above two methods for more
        information.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        audio = kwargs.pop("audio", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)
        if len(args) > 0:
            audio = args[0]
            args = args[1:]

        if audio is None and text is None:
            raise ValueError("You need to specify either an `audio` or `text` input to process.")

        if text is not None:
            inputs = self.tokenizer(text, **kwargs)

        if audio is not None:
            audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)

        if audio is None:
            return inputs

        elif text is None:
            return audio_inputs

        else:
            inputs["input_values"] = audio_inputs["input_values"]
            if "padding_mask" in audio_inputs:
                inputs["padding_mask"] = audio_inputs["padding_mask"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """
        This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
        from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
        """
        audio_values = kwargs.pop("audio", None)
        padding_mask = kwargs.pop("padding_mask", None)

        if len(args) > 0:
            audio_values = args[0]
            args = args[1:]

        if audio_values is not None:
            return self._decode_audio(audio_values, padding_mask=padding_mask)
        else:
            return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
        docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]:
        """
        This method strips any padding from the audio values to return a list of numpy audio arrays.
        """
        audio_values = to_numpy(audio_values)
        bsz, channels, seq_len = audio_values.shape

        if padding_mask is None:
            return list(audio_values)

        padding_mask = to_numpy(padding_mask)

        # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
        # token (so that the generated audio values are **not** treated as padded tokens)
        difference = seq_len - padding_mask.shape[-1]
        padding_value = 1 - self.feature_extractor.padding_value
        padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)

        audio_values = audio_values.tolist()
        for i in range(bsz):
            sliced_audio = np.asarray(audio_values[i])[
                padding_mask[i][None, :] != self.feature_extractor.padding_value
            ]
            audio_values[i] = sliced_audio.reshape(channels, -1)

        return audio_values


__all__ = ["MusicgenProcessor"]
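A short, hedged sketch of the text/audio dispatch implemented in `__call__` above; the checkpoint id and the 32 kHz sampling rate are assumptions taken from the public facebook/musicgen-small model card, not from this diff:

```python
# Hedged sketch of MusicgenProcessor dispatch: text -> T5 tokenizer, audio -> EnCodec feature extractor.
import numpy as np
from transformers import MusicgenProcessor

processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")  # assumed checkpoint id

# Text only: returns tokenizer outputs (input_ids, attention_mask).
text_inputs = processor(text=["80s pop track with heavy drums"], padding=True, return_tensors="pt")

# Audio only: returns feature-extractor outputs (input_values, padding_mask).
audio = np.zeros(32_000, dtype=np.float32)  # one second of silence at an assumed 32 kHz
audio_inputs = processor(audio=audio, sampling_rate=32_000, return_tensors="pt")
```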
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_olmoe import *
    from .modeling_olmoe import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (763 Bytes).
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc
ADDED
Binary file (8.47 kB).
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc
ADDED
Binary file (66 kB).
.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py
ADDED
@@ -0,0 +1,182 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OLMoE model configuration"""

from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation


class OlmoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`OlmoeModel`]. It is used to instantiate an OLMoE
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`OlmoeModel`]
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 16):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        clip_qkv (`float`, *optional*):
            If not `None`, elements of query, key and value attention states are clipped so that their
            absolute value does not exceed this value.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 64):
            Number of routed experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
            The aux loss factor for the total loss.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.

    ```python
    >>> from transformers import OlmoeModel, OlmoeConfig

    >>> # Initializing a OLMoE 7B A1B style configuration
    >>> configuration = OlmoeConfig()

    >>> # Initializing a model from the OLMoE 7B A1B style configuration
    >>> model = OlmoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "olmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=2048,
        intermediate_size=2048,
        num_hidden_layers=16,
        num_attention_heads=16,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        clip_qkv=None,
        num_experts_per_tok=8,
        num_experts=64,
        output_router_logits=False,
        router_aux_loss_coef=0.01,
        norm_topk_prob=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.clip_qkv = clip_qkv
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.norm_topk_prob = norm_topk_prob
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["OlmoeConfig"]
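The `num_key_value_heads` docstring above describes the MHA/GQA/MQA trade-off; here is a small illustrative calculation of the resulting head grouping. The value 4 for `num_key_value_heads` is hypothetical, chosen only to show the arithmetic (the config's default of `None` falls back to plain MHA):

```python
# Hypothetical GQA bookkeeping for an OlmoeConfig-like setup; numbers other than hidden_size and
# num_attention_heads are illustrative assumptions, not OLMoE defaults.
hidden_size = 2048
num_attention_heads = 16
num_key_value_heads = 4  # hypothetical; default None -> num_attention_heads (MHA)

head_dim = hidden_size // num_attention_heads                      # 128
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4 query heads share each kv head
q_proj_out = num_attention_heads * head_dim                        # 2048
kv_proj_out = num_key_value_heads * head_dim                       # 512 per k_proj / v_proj
print(head_dim, num_key_value_groups, q_proj_out, kv_proj_out)     # 128 4 2048 512
```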
.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py
ADDED
@@ -0,0 +1,1299 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OLMoE model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_olmoe import OlmoeConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "OlmoeConfig"


# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts


class OlmoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        OlmoeRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe
class OlmoeRotaryEmbedding(nn.Module):
    def __init__(self, config: OlmoeConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe
class OlmoeMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class OlmoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(
            (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
def forward(
|
| 327 |
+
self,
|
| 328 |
+
hidden_states: torch.Tensor,
|
| 329 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 330 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 331 |
+
past_key_value: Optional[Cache] = None,
|
| 332 |
+
output_attentions: bool = False,
|
| 333 |
+
use_cache: bool = False,
|
| 334 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 335 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 336 |
+
**kwargs,
|
| 337 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 338 |
+
bsz, q_len, _ = hidden_states.size()
|
| 339 |
+
|
| 340 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 341 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 342 |
+
value_states = self.v_proj(hidden_states)
|
| 343 |
+
|
| 344 |
+
if self.config.clip_qkv is not None:
|
| 345 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 346 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 347 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 348 |
+
|
| 349 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 350 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 351 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 352 |
+
|
| 353 |
+
cos, sin = position_embeddings
|
| 354 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 355 |
+
|
| 356 |
+
if past_key_value is not None:
|
| 357 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 358 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 359 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 360 |
+
|
| 361 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 362 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 363 |
+
|
| 364 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 365 |
+
|
| 366 |
+
if attention_mask is not None: # no matter the length, we just slice it
|
| 367 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 368 |
+
attn_weights = attn_weights + causal_mask
|
| 369 |
+
|
| 370 |
+
# upcast attention to fp32
|
| 371 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 372 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 373 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 374 |
+
|
| 375 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 378 |
+
f" {attn_output.size()}"
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 382 |
+
|
| 383 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 384 |
+
|
| 385 |
+
attn_output = self.o_proj(attn_output)
|
| 386 |
+
|
| 387 |
+
if not output_attentions:
|
| 388 |
+
attn_weights = None
|
| 389 |
+
|
| 390 |
+
return attn_output, attn_weights, past_key_value
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class OlmoeFlashAttention2(OlmoeAttention):
|
| 394 |
+
"""
|
| 395 |
+
OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays
|
| 396 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 397 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
def __init__(self, *args, **kwargs):
|
| 401 |
+
super().__init__(*args, **kwargs)
|
| 402 |
+
|
| 403 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 404 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 405 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 406 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 407 |
+
|
| 408 |
+
def forward(
|
| 409 |
+
self,
|
| 410 |
+
hidden_states: torch.Tensor,
|
| 411 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 412 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 413 |
+
past_key_value: Optional[Cache] = None,
|
| 414 |
+
output_attentions: bool = False,
|
| 415 |
+
use_cache: bool = False,
|
| 416 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 417 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 418 |
+
**kwargs,
|
| 419 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 420 |
+
output_attentions = False
|
| 421 |
+
|
| 422 |
+
bsz, q_len, _ = hidden_states.size()
|
| 423 |
+
|
| 424 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 425 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 426 |
+
value_states = self.v_proj(hidden_states)
|
| 427 |
+
if self.config.clip_qkv is not None:
|
| 428 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 429 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 430 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 431 |
+
|
| 432 |
+
# Flash attention requires the input to have the shape
|
| 433 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 434 |
+
# therefore we just need to keep the original shape
|
| 435 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 436 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 437 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 438 |
+
|
| 439 |
+
cos, sin = position_embeddings
|
| 440 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 441 |
+
|
| 442 |
+
if past_key_value is not None:
|
| 443 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 444 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 445 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 446 |
+
|
| 447 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 448 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 449 |
+
query_states = query_states.transpose(1, 2)
|
| 450 |
+
key_states = key_states.transpose(1, 2)
|
| 451 |
+
value_states = value_states.transpose(1, 2)
|
| 452 |
+
|
| 453 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 454 |
+
|
| 455 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 456 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 457 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 458 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 459 |
+
# in fp32. (OlmoeRMSNorm handles it correctly)
|
| 460 |
+
|
| 461 |
+
input_dtype = query_states.dtype
|
| 462 |
+
if input_dtype == torch.float32:
|
| 463 |
+
if torch.is_autocast_enabled():
|
| 464 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 465 |
+
# Handle the case where the model is quantized
|
| 466 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 467 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 468 |
+
else:
|
| 469 |
+
target_dtype = self.q_proj.weight.dtype
|
| 470 |
+
|
| 471 |
+
logger.warning_once(
|
| 472 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 473 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 474 |
+
f" {target_dtype}."
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
query_states = query_states.to(target_dtype)
|
| 478 |
+
key_states = key_states.to(target_dtype)
|
| 479 |
+
value_states = value_states.to(target_dtype)
|
| 480 |
+
|
| 481 |
+
attn_output = _flash_attention_forward(
|
| 482 |
+
query_states,
|
| 483 |
+
key_states,
|
| 484 |
+
value_states,
|
| 485 |
+
attention_mask,
|
| 486 |
+
q_len,
|
| 487 |
+
dropout=dropout_rate,
|
| 488 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 489 |
+
is_causal=self.is_causal,
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 493 |
+
attn_output = self.o_proj(attn_output)
|
| 494 |
+
|
| 495 |
+
if not output_attentions:
|
| 496 |
+
attn_weights = None
|
| 497 |
+
|
| 498 |
+
return attn_output, attn_weights, past_key_value
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class OlmoeSdpaAttention(OlmoeAttention):
|
| 502 |
+
"""
|
| 503 |
+
OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 504 |
+
`OlmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
| 505 |
+
SDPA API.
|
| 506 |
+
"""
|
| 507 |
+
|
| 508 |
+
# Adapted from OlmoeAttention.forward
|
| 509 |
+
def forward(
|
| 510 |
+
self,
|
| 511 |
+
hidden_states: torch.Tensor,
|
| 512 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 513 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 514 |
+
past_key_value: Optional[Cache] = None,
|
| 515 |
+
output_attentions: bool = False,
|
| 516 |
+
use_cache: bool = False,
|
| 517 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 518 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 519 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 520 |
+
if output_attentions:
|
| 521 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
| 522 |
+
logger.warning_once(
|
| 523 |
+
"OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
| 524 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 525 |
+
)
|
| 526 |
+
return super().forward(
|
| 527 |
+
hidden_states=hidden_states,
|
| 528 |
+
attention_mask=attention_mask,
|
| 529 |
+
position_ids=position_ids,
|
| 530 |
+
past_key_value=past_key_value,
|
| 531 |
+
output_attentions=output_attentions,
|
| 532 |
+
use_cache=use_cache,
|
| 533 |
+
cache_position=cache_position,
|
| 534 |
+
position_embeddings=position_embeddings,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
bsz, q_len, _ = hidden_states.size()
|
| 538 |
+
|
| 539 |
+
query_states = self.q_norm(self.q_proj(hidden_states))
|
| 540 |
+
key_states = self.k_norm(self.k_proj(hidden_states))
|
| 541 |
+
value_states = self.v_proj(hidden_states)
|
| 542 |
+
|
| 543 |
+
if self.config.clip_qkv is not None:
|
| 544 |
+
query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 545 |
+
key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 546 |
+
value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
|
| 547 |
+
|
| 548 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 549 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 550 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 551 |
+
|
| 552 |
+
cos, sin = position_embeddings
|
| 553 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 554 |
+
|
| 555 |
+
if past_key_value is not None:
|
| 556 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 557 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 558 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 559 |
+
|
| 560 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 561 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 562 |
+
|
| 563 |
+
causal_mask = attention_mask
|
| 564 |
+
# if attention_mask is not None and cache_position is not None:
|
| 565 |
+
if attention_mask is not None:
|
| 566 |
+
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
| 567 |
+
|
| 568 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
| 569 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
| 570 |
+
if query_states.device.type == "cuda" and causal_mask is not None:
|
| 571 |
+
query_states = query_states.contiguous()
|
| 572 |
+
key_states = key_states.contiguous()
|
| 573 |
+
value_states = value_states.contiguous()
|
| 574 |
+
|
| 575 |
+
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
| 576 |
+
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
| 577 |
+
is_causal = True if causal_mask is None and q_len > 1 else False
|
| 578 |
+
|
| 579 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
| 580 |
+
query_states,
|
| 581 |
+
key_states,
|
| 582 |
+
value_states,
|
| 583 |
+
attn_mask=causal_mask,
|
| 584 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 585 |
+
is_causal=is_causal,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 589 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
| 590 |
+
|
| 591 |
+
attn_output = self.o_proj(attn_output)
|
| 592 |
+
|
| 593 |
+
return attn_output, None, past_key_value
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
OLMOE_ATTENTION_CLASSES = {
|
| 597 |
+
"eager": OlmoeAttention,
|
| 598 |
+
"flash_attention_2": OlmoeFlashAttention2,
|
| 599 |
+
"sdpa": OlmoeSdpaAttention,
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class OlmoeSparseMoeBlock(nn.Module):
|
| 604 |
+
def __init__(self, config):
|
| 605 |
+
super().__init__()
|
| 606 |
+
self.num_experts = config.num_experts
|
| 607 |
+
self.top_k = config.num_experts_per_tok
|
| 608 |
+
self.norm_topk_prob = config.norm_topk_prob
|
| 609 |
+
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
|
| 610 |
+
self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
|
| 611 |
+
|
| 612 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 613 |
+
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
| 614 |
+
hidden_states = hidden_states.view(-1, hidden_dim)
|
| 615 |
+
# router_logits: (batch * sequence_length, n_experts)
|
| 616 |
+
router_logits = self.gate(hidden_states)
|
| 617 |
+
|
| 618 |
+
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
| 619 |
+
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
| 620 |
+
if self.norm_topk_prob:
|
| 621 |
+
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
| 622 |
+
# we cast back to the input dtype
|
| 623 |
+
routing_weights = routing_weights.to(hidden_states.dtype)
|
| 624 |
+
|
| 625 |
+
final_hidden_states = torch.zeros(
|
| 626 |
+
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# One hot encode the selected experts to create an expert mask
|
| 630 |
+
# this will be used to easily index which expert is going to be selected
|
| 631 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
| 632 |
+
|
| 633 |
+
# Loop over all available experts in the model and perform the computation on each expert
|
| 634 |
+
for expert_idx in range(self.num_experts):
|
| 635 |
+
expert_layer = self.experts[expert_idx]
|
| 636 |
+
idx, top_x = torch.where(expert_mask[expert_idx])
|
| 637 |
+
|
| 638 |
+
# Index the correct hidden states and compute the expert hidden state for
|
| 639 |
+
# the current expert. We need to make sure to multiply the output hidden
|
| 640 |
+
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
| 641 |
+
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
| 642 |
+
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
| 643 |
+
|
| 644 |
+
# However `index_add_` only support torch tensors for indexing so we'll use
|
| 645 |
+
# the `top_x` tensor here.
|
| 646 |
+
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 647 |
+
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
| 648 |
+
return final_hidden_states, router_logits
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
class OlmoeDecoderLayer(nn.Module):
|
| 652 |
+
def __init__(self, config: OlmoeConfig, layer_idx: int):
|
| 653 |
+
super().__init__()
|
| 654 |
+
self.hidden_size = config.hidden_size
|
| 655 |
+
|
| 656 |
+
self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
| 657 |
+
|
| 658 |
+
self.mlp = OlmoeSparseMoeBlock(config)
|
| 659 |
+
self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 660 |
+
self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 661 |
+
|
| 662 |
+
def forward(
|
| 663 |
+
self,
|
| 664 |
+
hidden_states: torch.Tensor,
|
| 665 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 666 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 667 |
+
past_key_value: Optional[Cache] = None,
|
| 668 |
+
output_attentions: Optional[bool] = False,
|
| 669 |
+
output_router_logits: Optional[bool] = False,
|
| 670 |
+
use_cache: Optional[bool] = False,
|
| 671 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 672 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 673 |
+
**kwargs,
|
| 674 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 675 |
+
"""
|
| 676 |
+
Args:
|
| 677 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 678 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
| 679 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
| 680 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
| 681 |
+
output_attentions (`bool`, *optional*):
|
| 682 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 683 |
+
returned tensors for more detail.
|
| 684 |
+
output_router_logits (`bool`, *optional*):
|
| 685 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss,
|
| 686 |
+
and should not be returned during inference.
|
| 687 |
+
use_cache (`bool`, *optional*):
|
| 688 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 689 |
+
(see `past_key_values`).
|
| 690 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 691 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 692 |
+
Indices depicting the position of the input sequence tokens in the sequence
|
| 693 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 694 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 695 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 696 |
+
kwargs (`dict`, *optional*):
|
| 697 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
| 698 |
+
into the model
|
| 699 |
+
"""
|
| 700 |
+
residual = hidden_states
|
| 701 |
+
|
| 702 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 703 |
+
|
| 704 |
+
# Self Attention
|
| 705 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 706 |
+
hidden_states=hidden_states,
|
| 707 |
+
attention_mask=attention_mask,
|
| 708 |
+
position_ids=position_ids,
|
| 709 |
+
past_key_value=past_key_value,
|
| 710 |
+
output_attentions=output_attentions,
|
| 711 |
+
use_cache=use_cache,
|
| 712 |
+
cache_position=cache_position,
|
| 713 |
+
position_embeddings=position_embeddings,
|
| 714 |
+
**kwargs,
|
| 715 |
+
)
|
| 716 |
+
hidden_states = residual + hidden_states
|
| 717 |
+
|
| 718 |
+
# Fully Connected
|
| 719 |
+
residual = hidden_states
|
| 720 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 721 |
+
hidden_states, router_logits = self.mlp(hidden_states)
|
| 722 |
+
hidden_states = residual + hidden_states
|
| 723 |
+
|
| 724 |
+
outputs = (hidden_states,)
|
| 725 |
+
|
| 726 |
+
if output_attentions:
|
| 727 |
+
outputs += (self_attn_weights,)
|
| 728 |
+
|
| 729 |
+
if use_cache:
|
| 730 |
+
outputs += (present_key_value,)
|
| 731 |
+
|
| 732 |
+
if output_router_logits:
|
| 733 |
+
outputs += (router_logits,)
|
| 734 |
+
|
| 735 |
+
return outputs
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
OLMOE_START_DOCSTRING = r"""
|
| 739 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 740 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 741 |
+
etc.)
|
| 742 |
+
|
| 743 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 744 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 745 |
+
and behavior.
|
| 746 |
+
|
| 747 |
+
Parameters:
|
| 748 |
+
config ([`OlmoeConfig`]):
|
| 749 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 750 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 751 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 752 |
+
"""
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
@add_start_docstrings(
|
| 756 |
+
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
|
| 757 |
+
OLMOE_START_DOCSTRING,
|
| 758 |
+
)
|
| 759 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmoe
|
| 760 |
+
class OlmoePreTrainedModel(PreTrainedModel):
|
| 761 |
+
config_class = OlmoeConfig
|
| 762 |
+
base_model_prefix = "model"
|
| 763 |
+
supports_gradient_checkpointing = True
|
| 764 |
+
_no_split_modules = ["OlmoeDecoderLayer"]
|
| 765 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 766 |
+
_supports_flash_attn_2 = True
|
| 767 |
+
_supports_sdpa = True
|
| 768 |
+
_supports_flex_attn = True
|
| 769 |
+
_supports_cache_class = True
|
| 770 |
+
_supports_quantized_cache = True
|
| 771 |
+
_supports_static_cache = True
|
| 772 |
+
|
| 773 |
+
def _init_weights(self, module):
|
| 774 |
+
std = self.config.initializer_range
|
| 775 |
+
if isinstance(module, nn.Linear):
|
| 776 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 777 |
+
if module.bias is not None:
|
| 778 |
+
module.bias.data.zero_()
|
| 779 |
+
elif isinstance(module, nn.Embedding):
|
| 780 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 781 |
+
if module.padding_idx is not None:
|
| 782 |
+
module.weight.data[module.padding_idx].zero_()
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
OLMOE_INPUTS_DOCSTRING = r"""
|
| 786 |
+
Args:
|
| 787 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 788 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 789 |
+
it.
|
| 790 |
+
|
| 791 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 792 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 793 |
+
|
| 794 |
+
[What are input IDs?](../glossary#input-ids)
|
| 795 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 796 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 797 |
+
|
| 798 |
+
- 1 for tokens that are **not masked**,
|
| 799 |
+
- 0 for tokens that are **masked**.
|
| 800 |
+
|
| 801 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 802 |
+
|
| 803 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 804 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 805 |
+
|
| 806 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
| 807 |
+
`past_key_values`).
|
| 808 |
+
|
| 809 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
| 810 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
| 811 |
+
information on the default strategy.
|
| 812 |
+
|
| 813 |
+
- 1 indicates the head is **not masked**,
|
| 814 |
+
- 0 indicates the head is **masked**.
|
| 815 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 816 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 817 |
+
config.n_positions - 1]`.
|
| 818 |
+
|
| 819 |
+
[What are position IDs?](../glossary#position-ids)
|
| 820 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
| 821 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
| 822 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
| 823 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
| 824 |
+
|
| 825 |
+
Two formats are allowed:
|
| 826 |
+
- a [`~cache_utils.Cache`] instance;
|
| 827 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
| 828 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
| 829 |
+
cache format.
|
| 830 |
+
|
| 831 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
| 832 |
+
legacy cache format will be returned.
|
| 833 |
+
|
| 834 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
| 835 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
| 836 |
+
of shape `(batch_size, sequence_length)`.
|
| 837 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
| 838 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 839 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 840 |
+
model's internal embedding lookup matrix.
|
| 841 |
+
use_cache (`bool`, *optional*):
|
| 842 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
| 843 |
+
`past_key_values`).
|
| 844 |
+
output_attentions (`bool`, *optional*):
|
| 845 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 846 |
+
tensors for more detail.
|
| 847 |
+
output_hidden_states (`bool`, *optional*):
|
| 848 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 849 |
+
more detail.
|
| 850 |
+
output_router_logits (`bool`, *optional*):
|
| 851 |
+
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
| 852 |
+
should not be returned during inference.
|
| 853 |
+
return_dict (`bool`, *optional*):
|
| 854 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 855 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 856 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
| 857 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
| 858 |
+
the complete sequence length.
|
| 859 |
+
"""
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
@add_start_docstrings(
|
| 863 |
+
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
|
| 864 |
+
OLMOE_START_DOCSTRING,
|
| 865 |
+
)
|
| 866 |
+
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
|
| 867 |
+
class OlmoeModel(OlmoePreTrainedModel):
|
| 868 |
+
"""
|
| 869 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
|
| 870 |
+
|
| 871 |
+
Args:
|
| 872 |
+
config: OlmoeConfig
|
| 873 |
+
"""
|
| 874 |
+
|
| 875 |
+
def __init__(self, config: OlmoeConfig):
|
| 876 |
+
super().__init__(config)
|
| 877 |
+
self.padding_idx = config.pad_token_id
|
| 878 |
+
self.vocab_size = config.vocab_size
|
| 879 |
+
|
| 880 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 881 |
+
self.layers = nn.ModuleList(
|
| 882 |
+
[OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 883 |
+
)
|
| 884 |
+
self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 885 |
+
self.rotary_emb = OlmoeRotaryEmbedding(config=config)
|
| 886 |
+
self.gradient_checkpointing = False
|
| 887 |
+
|
| 888 |
+
# Initialize weights and apply final processing
|
| 889 |
+
self.post_init()
|
| 890 |
+
|
| 891 |
+
def get_input_embeddings(self):
|
| 892 |
+
return self.embed_tokens
|
| 893 |
+
|
| 894 |
+
def set_input_embeddings(self, value):
|
| 895 |
+
self.embed_tokens = value
|
| 896 |
+
|
| 897 |
+
@add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
|
| 898 |
+
# Ignore copy
|
| 899 |
+
def forward(
|
| 900 |
+
self,
|
| 901 |
+
input_ids: torch.LongTensor = None,
|
| 902 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 903 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 904 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 905 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 906 |
+
use_cache: Optional[bool] = None,
|
| 907 |
+
output_attentions: Optional[bool] = None,
|
| 908 |
+
output_hidden_states: Optional[bool] = None,
|
| 909 |
+
output_router_logits: Optional[bool] = None,
|
| 910 |
+
return_dict: Optional[bool] = None,
|
| 911 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 912 |
+
) -> Union[Tuple, MoeModelOutputWithPast]:
|
| 913 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 914 |
+
output_router_logits = (
|
| 915 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 916 |
+
)
|
| 917 |
+
output_hidden_states = (
|
| 918 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 919 |
+
)
|
| 920 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 921 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 922 |
+
|
| 923 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 924 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 925 |
+
|
| 926 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 927 |
+
logger.warning_once(
|
| 928 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 929 |
+
)
|
| 930 |
+
use_cache = False
|
| 931 |
+
|
| 932 |
+
if inputs_embeds is None:
|
| 933 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 934 |
+
|
| 935 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
| 936 |
+
return_legacy_cache = False
|
| 937 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
| 938 |
+
return_legacy_cache = True
|
| 939 |
+
if past_key_values is None:
|
| 940 |
+
past_key_values = DynamicCache()
|
| 941 |
+
else:
|
| 942 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 943 |
+
logger.warning_once(
|
| 944 |
+
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
| 945 |
+
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
| 946 |
+
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
if cache_position is None:
|
| 950 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 951 |
+
cache_position = torch.arange(
|
| 952 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 953 |
+
)
|
| 954 |
+
if position_ids is None:
|
| 955 |
+
position_ids = cache_position.unsqueeze(0)
|
| 956 |
+
|
| 957 |
+
causal_mask = self._update_causal_mask(
|
| 958 |
+
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
| 959 |
+
)
|
| 960 |
+
|
| 961 |
+
# embed positions
|
| 962 |
+
hidden_states = inputs_embeds
|
| 963 |
+
|
| 964 |
+
# create position embeddings to be shared across the decoder layers
|
| 965 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 966 |
+
|
| 967 |
+
# decoder layers
|
| 968 |
+
all_hidden_states = () if output_hidden_states else None
|
| 969 |
+
all_self_attns = () if output_attentions else None
|
| 970 |
+
all_router_logits = () if output_router_logits else None
|
| 971 |
+
next_decoder_cache = None
|
| 972 |
+
|
| 973 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 974 |
+
if output_hidden_states:
|
| 975 |
+
all_hidden_states += (hidden_states,)
|
| 976 |
+
|
| 977 |
+
if self.gradient_checkpointing and self.training:
|
| 978 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 979 |
+
decoder_layer.__call__,
|
| 980 |
+
hidden_states,
|
| 981 |
+
causal_mask,
|
| 982 |
+
position_ids,
|
| 983 |
+
past_key_values,
|
| 984 |
+
output_attentions,
|
| 985 |
+
output_router_logits,
|
| 986 |
+
use_cache,
|
| 987 |
+
cache_position,
|
| 988 |
+
position_embeddings,
|
| 989 |
+
)
|
| 990 |
+
else:
|
| 991 |
+
layer_outputs = decoder_layer(
|
| 992 |
+
hidden_states,
|
| 993 |
+
attention_mask=causal_mask,
|
| 994 |
+
position_ids=position_ids,
|
| 995 |
+
past_key_value=past_key_values,
|
| 996 |
+
output_attentions=output_attentions,
|
| 997 |
+
output_router_logits=output_router_logits,
|
| 998 |
+
use_cache=use_cache,
|
| 999 |
+
cache_position=cache_position,
|
| 1000 |
+
position_embeddings=position_embeddings,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
hidden_states = layer_outputs[0]
|
| 1004 |
+
|
| 1005 |
+
if use_cache:
|
| 1006 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 1007 |
+
|
| 1008 |
+
if output_attentions:
|
| 1009 |
+
all_self_attns += (layer_outputs[1],)
|
| 1010 |
+
|
| 1011 |
+
if output_router_logits and layer_outputs[-1] is not None:
|
| 1012 |
+
all_router_logits += (layer_outputs[-1],)
|
| 1013 |
+
|
| 1014 |
+
hidden_states = self.norm(hidden_states)
|
| 1015 |
+
|
| 1016 |
+
# add hidden states from the last decoder layer
|
| 1017 |
+
if output_hidden_states:
|
| 1018 |
+
all_hidden_states += (hidden_states,)
|
| 1019 |
+
|
| 1020 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 1021 |
+
if return_legacy_cache:
|
| 1022 |
+
next_cache = next_cache.to_legacy_cache()
|
| 1023 |
+
|
| 1024 |
+
if not return_dict:
|
| 1025 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 1026 |
+
return MoeModelOutputWithPast(
|
| 1027 |
+
last_hidden_state=hidden_states,
|
| 1028 |
+
past_key_values=next_cache,
|
| 1029 |
+
hidden_states=all_hidden_states,
|
| 1030 |
+
attentions=all_self_attns,
|
| 1031 |
+
router_logits=all_router_logits,
|
| 1032 |
+
)
|
| 1033 |
+
|
| 1034 |
+
def _update_causal_mask(
|
| 1035 |
+
self,
|
| 1036 |
+
attention_mask: torch.Tensor,
|
| 1037 |
+
input_tensor: torch.Tensor,
|
| 1038 |
+
cache_position: torch.Tensor,
|
| 1039 |
+
past_key_values: Cache,
|
| 1040 |
+
output_attentions: bool,
|
| 1041 |
+
):
|
| 1042 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 1043 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
| 1044 |
+
return attention_mask
|
| 1045 |
+
return None
|
| 1046 |
+
|
| 1047 |
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
| 1048 |
+
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
| 1049 |
+
# to infer the attention mask.
|
| 1050 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1051 |
+
using_static_cache = isinstance(past_key_values, StaticCache)
|
| 1052 |
+
|
| 1053 |
+
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
| 1054 |
+
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
| 1055 |
+
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 1056 |
+
attention_mask,
|
| 1057 |
+
inputs_embeds=input_tensor,
|
| 1058 |
+
past_key_values_length=past_seen_tokens,
|
| 1059 |
+
is_training=self.training,
|
| 1060 |
+
):
|
| 1061 |
+
return None
|
| 1062 |
+
|
| 1063 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
| 1064 |
+
sequence_length = input_tensor.shape[1]
|
| 1065 |
+
if using_static_cache:
|
| 1066 |
+
target_length = past_key_values.get_max_cache_shape()
|
| 1067 |
+
else:
|
| 1068 |
+
target_length = (
|
| 1069 |
+
attention_mask.shape[-1]
|
| 1070 |
+
if isinstance(attention_mask, torch.Tensor)
|
| 1071 |
+
else past_seen_tokens + sequence_length + 1
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
| 1075 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
| 1076 |
+
attention_mask,
|
| 1077 |
+
sequence_length=sequence_length,
|
| 1078 |
+
target_length=target_length,
|
| 1079 |
+
dtype=dtype,
|
| 1080 |
+
device=device,
|
| 1081 |
+
cache_position=cache_position,
|
| 1082 |
+
batch_size=input_tensor.shape[0],
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
if (
|
| 1086 |
+
self.config._attn_implementation == "sdpa"
|
| 1087 |
+
and attention_mask is not None
|
| 1088 |
+
and attention_mask.device.type == "cuda"
|
| 1089 |
+
and not output_attentions
|
| 1090 |
+
):
|
| 1091 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
| 1092 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 1093 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 1094 |
+
min_dtype = torch.finfo(dtype).min
|
| 1095 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
| 1096 |
+
|
| 1097 |
+
return causal_mask
|
| 1098 |
+
|
| 1099 |
+
@staticmethod
|
| 1100 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
| 1101 |
+
attention_mask: torch.Tensor,
|
| 1102 |
+
sequence_length: int,
|
| 1103 |
+
target_length: int,
|
| 1104 |
+
dtype: torch.dtype,
|
| 1105 |
+
device: torch.device,
|
| 1106 |
+
cache_position: torch.Tensor,
|
| 1107 |
+
batch_size: int,
|
| 1108 |
+
**kwargs,
|
| 1109 |
+
):
|
| 1110 |
+
"""
|
| 1111 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 1112 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
| 1113 |
+
|
| 1114 |
+
Args:
|
| 1115 |
+
attention_mask (`torch.Tensor`):
|
| 1116 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
| 1117 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
| 1118 |
+
sequence_length (`int`):
|
| 1119 |
+
The sequence length being processed.
|
| 1120 |
+
target_length (`int`):
|
| 1121 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
| 1122 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
| 1123 |
+
dtype (`torch.dtype`):
|
| 1124 |
+
The dtype to use for the 4D attention mask.
|
| 1125 |
+
device (`torch.device`):
|
| 1126 |
+
The device to plcae the 4D attention mask on.
|
| 1127 |
+
cache_position (`torch.Tensor`):
|
| 1128 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 1129 |
+
batch_size (`torch.Tensor`):
|
| 1130 |
+
Batch size.
|
| 1131 |
+
"""
|
| 1132 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
| 1133 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
| 1134 |
+
causal_mask = attention_mask
|
| 1135 |
+
else:
|
| 1136 |
+
min_dtype = torch.finfo(dtype).min
|
| 1137 |
+
causal_mask = torch.full(
|
| 1138 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
| 1139 |
+
)
|
| 1140 |
+
if sequence_length != 1:
|
| 1141 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
| 1142 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
| 1143 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
| 1144 |
+
if attention_mask is not None:
|
| 1145 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
| 1146 |
+
mask_length = attention_mask.shape[-1]
|
| 1147 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
| 1148 |
+
padding_mask = padding_mask == 0
|
| 1149 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
| 1150 |
+
padding_mask, min_dtype
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
return causal_mask
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
|
| 1157 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1158 |
+
|
| 1159 |
+
def __init__(self, config):
|
| 1160 |
+
super().__init__(config)
|
| 1161 |
+
self.model = OlmoeModel(config)
|
| 1162 |
+
self.vocab_size = config.vocab_size
|
| 1163 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1164 |
+
|
| 1165 |
+
self.router_aux_loss_coef = config.router_aux_loss_coef
|
| 1166 |
+
self.num_experts = config.num_experts
|
| 1167 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 1168 |
+
# Initialize weights and apply final processing
|
| 1169 |
+
self.post_init()
|
| 1170 |
+
|
| 1171 |
+
def get_input_embeddings(self):
|
| 1172 |
+
return self.model.embed_tokens
|
| 1173 |
+
|
| 1174 |
+
def set_input_embeddings(self, value):
|
| 1175 |
+
self.model.embed_tokens = value
|
| 1176 |
+
|
| 1177 |
+
def get_output_embeddings(self):
|
| 1178 |
+
return self.lm_head
|
| 1179 |
+
|
| 1180 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1181 |
+
self.lm_head = new_embeddings
|
| 1182 |
+
|
| 1183 |
+
def set_decoder(self, decoder):
|
| 1184 |
+
self.model = decoder
|
| 1185 |
+
|
| 1186 |
+
def get_decoder(self):
|
| 1187 |
+
return self.model
|
| 1188 |
+
|
| 1189 |
+
@add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
|
| 1190 |
+
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
| 1191 |
+
def forward(
|
| 1192 |
+
self,
|
| 1193 |
+
input_ids: torch.LongTensor = None,
|
| 1194 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1195 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1196 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1197 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1198 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1199 |
+
use_cache: Optional[bool] = None,
|
| 1200 |
+
output_attentions: Optional[bool] = None,
|
| 1201 |
+
output_hidden_states: Optional[bool] = None,
|
| 1202 |
+
output_router_logits: Optional[bool] = None,
|
| 1203 |
+
return_dict: Optional[bool] = None,
|
| 1204 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1205 |
+
num_logits_to_keep: int = 0,
|
| 1206 |
+
**loss_kwargs,
|
| 1207 |
+
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
|
| 1208 |
+
r"""
|
| 1209 |
+
Args:
|
| 1210 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1211 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 1212 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 1213 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 1214 |
+
|
| 1215 |
+
num_logits_to_keep (`int`, *optional*):
|
| 1216 |
+
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
| 1217 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
| 1218 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
| 1219 |
+
|
| 1220 |
+
Returns:
|
| 1221 |
+
|
| 1222 |
+
Example:
|
| 1223 |
+
|
| 1224 |
+
```python
|
| 1225 |
+
>>> from transformers import AutoTokenizer, OlmoeForCausalLM
|
| 1226 |
+
|
| 1227 |
+
>>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
|
| 1228 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
|
| 1229 |
+
|
| 1230 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
| 1231 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
| 1232 |
+
|
| 1233 |
+
>>> # Generate
|
| 1234 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
| 1235 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 1236 |
+
'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
|
| 1237 |
+
```
|
| 1238 |
+
"""
|
| 1239 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1240 |
+
output_router_logits = (
|
| 1241 |
+
output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
| 1242 |
+
)
|
| 1243 |
+
output_hidden_states = (
|
| 1244 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1245 |
+
)
|
| 1246 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1247 |
+
|
| 1248 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1249 |
+
outputs = self.model(
|
| 1250 |
+
input_ids=input_ids,
|
| 1251 |
+
attention_mask=attention_mask,
|
| 1252 |
+
position_ids=position_ids,
|
| 1253 |
+
past_key_values=past_key_values,
|
| 1254 |
+
inputs_embeds=inputs_embeds,
|
| 1255 |
+
use_cache=use_cache,
|
| 1256 |
+
output_attentions=output_attentions,
|
| 1257 |
+
output_hidden_states=output_hidden_states,
|
| 1258 |
+
output_router_logits=output_router_logits,
|
| 1259 |
+
return_dict=return_dict,
|
| 1260 |
+
cache_position=cache_position,
|
| 1261 |
+
)
|
| 1262 |
+
|
| 1263 |
+
hidden_states = outputs[0]
|
| 1264 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 1265 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
| 1266 |
+
|
| 1267 |
+
loss = None
|
| 1268 |
+
if labels is not None:
|
| 1269 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
| 1270 |
+
|
| 1271 |
+
aux_loss = None
|
| 1272 |
+
if output_router_logits:
|
| 1273 |
+
aux_loss = load_balancing_loss_func(
|
| 1274 |
+
outputs.router_logits if return_dict else outputs[-1],
|
| 1275 |
+
self.num_experts,
|
| 1276 |
+
self.num_experts_per_tok,
|
| 1277 |
+
attention_mask,
|
| 1278 |
+
)
|
| 1279 |
+
if labels is not None:
|
| 1280 |
+
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 1281 |
+
|
| 1282 |
+
if not return_dict:
|
| 1283 |
+
output = (logits,) + outputs[1:]
|
| 1284 |
+
if output_router_logits:
|
| 1285 |
+
output = (aux_loss,) + output
|
| 1286 |
+
return (loss,) + output if loss is not None else output
|
| 1287 |
+
|
| 1288 |
+
return MoeCausalLMOutputWithPast(
|
| 1289 |
+
loss=loss,
|
| 1290 |
+
aux_loss=aux_loss,
|
| 1291 |
+
logits=logits,
|
| 1292 |
+
past_key_values=outputs.past_key_values,
|
| 1293 |
+
hidden_states=outputs.hidden_states,
|
| 1294 |
+
attentions=outputs.attentions,
|
| 1295 |
+
router_logits=outputs.router_logits,
|
| 1296 |
+
)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
__all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
|
.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_pegasus import *
|
| 22 |
+
from .modeling_flax_pegasus import *
|
| 23 |
+
from .modeling_pegasus import *
|
| 24 |
+
from .modeling_tf_pegasus import *
|
| 25 |
+
from .tokenization_pegasus import *
|
| 26 |
+
from .tokenization_pegasus_fast import *
|
| 27 |
+
else:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
_file = globals()["__file__"]
|
| 31 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (942 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc
ADDED
|
Binary file (7.29 kB). View file
|
|
|