koichi12 commited on
Commit
d1d6563
·
verified ·
1 Parent(s): 9e4dfa6

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py +27 -0
  2. .venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py +289 -0
  6. .venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py +197 -0
  7. .venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py +229 -0
  8. .venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py +36 -0
  9. .venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py +515 -0
  10. .venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py +956 -0
  11. .venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py +27 -0
  12. .venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py +452 -0
  16. .venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py +381 -0
  17. .venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py +30 -0
  18. .venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py +142 -0
  21. .venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py +36 -0
  22. .venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py +323 -0
  23. .venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py +551 -0
  24. .venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py +669 -0
  25. .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py +27 -0
  26. .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py +963 -0
  30. .venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py +26 -0
  32. .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py +299 -0
  35. .venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py +28 -0
  36. .venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py +247 -0
  40. .venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py +0 -0
  41. .venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py +144 -0
  42. .venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py +27 -0
  43. .venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py +182 -0
  47. .venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py +1299 -0
  48. .venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py +31 -0
  49. .venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/transformers/models/barthez/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_barthez import *
22
+ from .tokenization_barthez_fast import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (777 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/barthez/__pycache__/tokenization_barthez_fast.cpython-311.pyc ADDED
Binary file (8.97 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ """Tokenization classes for the BARThez model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import Any, Dict, List, Optional, Tuple
20
+
21
+ import sentencepiece as spm
22
+
23
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
24
+ from ...utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
30
+
31
+
32
+ SPIECE_UNDERLINE = "▁"
33
+
34
+ # TODO this class is useless. This is the most standard sentencpiece model. Let's find which one is closest and nuke this.
35
+
36
+
37
+ class BarthezTokenizer(PreTrainedTokenizer):
38
+ """
39
+ Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a BARThez tokenizer. Based on
40
+ [SentencePiece](https://github.com/google/sentencepiece).
41
+
42
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
43
+ this superclass for more information regarding those methods.
44
+
45
+ Args:
46
+ vocab_file (`str`):
47
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
48
+ contains the vocabulary necessary to instantiate a tokenizer.
49
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
50
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
51
+
52
+ <Tip>
53
+
54
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
55
+ sequence. The token used is the `cls_token`.
56
+
57
+ </Tip>
58
+
59
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
60
+ The end of sequence token.
61
+
62
+ <Tip>
63
+
64
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
65
+ The token used is the `sep_token`.
66
+
67
+ </Tip>
68
+
69
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
70
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
71
+ sequence classification or for a text and a question for question answering. It is also used as the last
72
+ token of a sequence built with special tokens.
73
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
74
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
75
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
76
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
77
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
78
+ token instead.
79
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
80
+ The token used for padding, for example when batching sequences of different lengths.
81
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
82
+ The token used for masking values. This is the token used when training this model with masked language
83
+ modeling. This is the token which the model will try to predict.
84
+ sp_model_kwargs (`dict`, *optional*):
85
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
86
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
87
+ to set:
88
+
89
+ - `enable_sampling`: Enable subword regularization.
90
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
91
+
92
+ - `nbest_size = {0,1}`: No sampling is performed.
93
+ - `nbest_size > 1`: samples from the nbest_size results.
94
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
95
+ using forward-filtering-and-backward-sampling algorithm.
96
+
97
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
98
+ BPE-dropout.
99
+
100
+ Attributes:
101
+ sp_model (`SentencePieceProcessor`):
102
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
103
+ """
104
+
105
+ vocab_files_names = VOCAB_FILES_NAMES
106
+ model_input_names = ["input_ids", "attention_mask"]
107
+
108
+ def __init__(
109
+ self,
110
+ vocab_file,
111
+ bos_token="<s>",
112
+ eos_token="</s>",
113
+ sep_token="</s>",
114
+ cls_token="<s>",
115
+ unk_token="<unk>",
116
+ pad_token="<pad>",
117
+ mask_token="<mask>",
118
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
119
+ **kwargs,
120
+ ) -> None:
121
+ # Mask token behave like a normal word, i.e. include the space before it. Will have normalized=False by default this way
122
+ mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
123
+
124
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
125
+
126
+ self.vocab_file = vocab_file
127
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
128
+ self.sp_model.Load(str(vocab_file))
129
+ super().__init__(
130
+ bos_token=bos_token,
131
+ eos_token=eos_token,
132
+ unk_token=unk_token,
133
+ sep_token=sep_token,
134
+ cls_token=cls_token,
135
+ pad_token=pad_token,
136
+ mask_token=mask_token,
137
+ sp_model_kwargs=self.sp_model_kwargs,
138
+ **kwargs,
139
+ )
140
+
141
+ def build_inputs_with_special_tokens(
142
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
143
+ ) -> List[int]:
144
+ """
145
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
146
+ adding special tokens. A BARThez sequence has the following format:
147
+
148
+ - single sequence: `<s> X </s>`
149
+ - pair of sequences: `<s> A </s></s> B </s>`
150
+
151
+ Args:
152
+ token_ids_0 (`List[int]`):
153
+ List of IDs to which the special tokens will be added.
154
+ token_ids_1 (`List[int]`, *optional*):
155
+ Optional second list of IDs for sequence pairs.
156
+
157
+ Returns:
158
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
159
+ """
160
+
161
+ if token_ids_1 is None:
162
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
163
+ cls = [self.cls_token_id]
164
+ sep = [self.sep_token_id]
165
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
166
+
167
+ def get_special_tokens_mask(
168
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
169
+ ) -> List[int]:
170
+ """
171
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
172
+ special tokens using the tokenizer `prepare_for_model` method.
173
+
174
+ Args:
175
+ token_ids_0 (`List[int]`):
176
+ List of IDs.
177
+ token_ids_1 (`List[int]`, *optional*):
178
+ Optional second list of IDs for sequence pairs.
179
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
180
+ Whether or not the token list is already formatted with special tokens for the model.
181
+
182
+ Returns:
183
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
184
+ """
185
+ if already_has_special_tokens:
186
+ return super().get_special_tokens_mask(
187
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
188
+ )
189
+
190
+ if token_ids_1 is None:
191
+ return [1] + ([0] * len(token_ids_0)) + [1]
192
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
193
+
194
+ def create_token_type_ids_from_sequences(
195
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
196
+ ) -> List[int]:
197
+ """
198
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
199
+
200
+ Args:
201
+ token_ids_0 (`List[int]`):
202
+ List of IDs.
203
+ token_ids_1 (`List[int]`, *optional*):
204
+ Optional second list of IDs for sequence pairs.
205
+
206
+ Returns:
207
+ `List[int]`: List of zeros.
208
+ """
209
+ sep = [self.sep_token_id]
210
+ cls = [self.cls_token_id]
211
+
212
+ if token_ids_1 is None:
213
+ return len(cls + token_ids_0 + sep) * [0]
214
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
215
+
216
+ @property
217
+ def vocab_size(self):
218
+ return len(self.sp_model)
219
+
220
+ def get_vocab(self):
221
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
222
+ vocab.update(self.added_tokens_encoder)
223
+ return vocab
224
+
225
+ def _tokenize(self, text: str) -> List[str]:
226
+ return self.sp_model.encode(text, out_type=str)
227
+
228
+ def _convert_token_to_id(self, token):
229
+ """Converts a token (str) in an id using the vocab."""
230
+ return self.sp_model.PieceToId(token)
231
+
232
+ def _convert_id_to_token(self, index):
233
+ """Converts an index (integer) in a token (str) using the vocab."""
234
+ return self.sp_model.IdToPiece(index)
235
+
236
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
237
+ def convert_tokens_to_string(self, tokens):
238
+ """Converts a sequence of tokens (string) in a single string."""
239
+ current_sub_tokens = []
240
+ out_string = ""
241
+ prev_is_special = False
242
+ for token in tokens:
243
+ # make sure that special tokens are not decoded using sentencepiece model
244
+ if token in self.all_special_tokens:
245
+ if not prev_is_special:
246
+ out_string += " "
247
+ out_string += self.sp_model.decode(current_sub_tokens) + token
248
+ prev_is_special = True
249
+ current_sub_tokens = []
250
+ else:
251
+ current_sub_tokens.append(token)
252
+ prev_is_special = False
253
+ out_string += self.sp_model.decode(current_sub_tokens)
254
+ return out_string.strip()
255
+
256
+ def __getstate__(self):
257
+ state = self.__dict__.copy()
258
+ state["sp_model"] = None
259
+ return state
260
+
261
+ def __setstate__(self, d):
262
+ self.__dict__ = d
263
+
264
+ # for backward compatibility
265
+ if not hasattr(self, "sp_model_kwargs"):
266
+ self.sp_model_kwargs = {}
267
+
268
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
269
+ self.sp_model.Load(self.vocab_file)
270
+
271
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
272
+ if not os.path.isdir(save_directory):
273
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
274
+ return
275
+ out_vocab_file = os.path.join(
276
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
277
+ )
278
+
279
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
280
+ copyfile(self.vocab_file, out_vocab_file)
281
+ elif not os.path.isfile(self.vocab_file):
282
+ with open(out_vocab_file, "wb") as fi:
283
+ content_spiece_model = self.sp_model.serialized_model_proto()
284
+ fi.write(content_spiece_model)
285
+
286
+ return (out_vocab_file,)
287
+
288
+
289
+ __all__ = ["BarthezTokenizer"]
.venv/lib/python3.11/site-packages/transformers/models/barthez/tokenization_barthez_fast.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License
15
+ """Tokenization classes for the BARThez model."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import List, Optional, Tuple
20
+
21
+ from ...tokenization_utils import AddedToken
22
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
23
+ from ...utils import is_sentencepiece_available, logging
24
+
25
+
26
+ if is_sentencepiece_available():
27
+ from .tokenization_barthez import BarthezTokenizer
28
+ else:
29
+ BarthezTokenizer = None
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}
34
+
35
+
36
+ SPIECE_UNDERLINE = "▁"
37
+
38
+
39
+ class BarthezTokenizerFast(PreTrainedTokenizerFast):
40
+ """
41
+ Adapted from [`CamembertTokenizer`] and [`BartTokenizer`]. Construct a "fast" BARThez tokenizer. Based on
42
+ [SentencePiece](https://github.com/google/sentencepiece).
43
+
44
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
45
+ refer to this superclass for more information regarding those methods.
46
+
47
+ Args:
48
+ vocab_file (`str`):
49
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
50
+ contains the vocabulary necessary to instantiate a tokenizer.
51
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
52
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
53
+
54
+ <Tip>
55
+
56
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
57
+ sequence. The token used is the `cls_token`.
58
+
59
+ </Tip>
60
+
61
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
62
+ The end of sequence token.
63
+
64
+ <Tip>
65
+
66
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
67
+ The token used is the `sep_token`.
68
+
69
+ </Tip>
70
+
71
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
72
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
73
+ sequence classification or for a text and a question for question answering. It is also used as the last
74
+ token of a sequence built with special tokens.
75
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
76
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
77
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
78
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
79
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
80
+ token instead.
81
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
82
+ The token used for padding, for example when batching sequences of different lengths.
83
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
84
+ The token used for masking values. This is the token used when training this model with masked language
85
+ modeling. This is the token which the model will try to predict.
86
+ additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
87
+ Additional special tokens used by the tokenizer.
88
+ """
89
+
90
+ vocab_files_names = VOCAB_FILES_NAMES
91
+ model_input_names = ["input_ids", "attention_mask"]
92
+ slow_tokenizer_class = BarthezTokenizer
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_file=None,
97
+ tokenizer_file=None,
98
+ bos_token="<s>",
99
+ eos_token="</s>",
100
+ sep_token="</s>",
101
+ cls_token="<s>",
102
+ unk_token="<unk>",
103
+ pad_token="<pad>",
104
+ mask_token="<mask>",
105
+ **kwargs,
106
+ ):
107
+ # Mask token behave like a normal word, i.e. include the space before it
108
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
109
+
110
+ super().__init__(
111
+ vocab_file,
112
+ tokenizer_file=tokenizer_file,
113
+ bos_token=bos_token,
114
+ eos_token=eos_token,
115
+ unk_token=unk_token,
116
+ sep_token=sep_token,
117
+ cls_token=cls_token,
118
+ pad_token=pad_token,
119
+ mask_token=mask_token,
120
+ **kwargs,
121
+ )
122
+
123
+ self.vocab_file = vocab_file
124
+
125
+ @property
126
+ def can_save_slow_tokenizer(self) -> bool:
127
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
128
+
129
+ def build_inputs_with_special_tokens(
130
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
131
+ ) -> List[int]:
132
+ """
133
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
134
+ adding special tokens. A BARThez sequence has the following format:
135
+
136
+ - single sequence: `<s> X </s>`
137
+ - pair of sequences: `<s> A </s></s> B </s>`
138
+
139
+ Args:
140
+ token_ids_0 (`List[int]`):
141
+ List of IDs to which the special tokens will be added.
142
+ token_ids_1 (`List[int]`, *optional*):
143
+ Optional second list of IDs for sequence pairs.
144
+
145
+ Returns:
146
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
147
+ """
148
+
149
+ if token_ids_1 is None:
150
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
151
+ cls = [self.cls_token_id]
152
+ sep = [self.sep_token_id]
153
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
154
+
155
+ def create_token_type_ids_from_sequences(
156
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
157
+ ) -> List[int]:
158
+ """
159
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
160
+
161
+ Args:
162
+ token_ids_0 (`List[int]`):
163
+ List of IDs.
164
+ token_ids_1 (`List[int]`, *optional*):
165
+ Optional second list of IDs for sequence pairs.
166
+
167
+ Returns:
168
+ `List[int]`: List of zeros.
169
+ """
170
+ sep = [self.sep_token_id]
171
+ cls = [self.cls_token_id]
172
+
173
+ if token_ids_1 is None:
174
+ return len(cls + token_ids_0 + sep) * [0]
175
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
176
+
177
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
178
+ if not self.can_save_slow_tokenizer:
179
+ raise ValueError(
180
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
181
+ "tokenizer."
182
+ )
183
+
184
+ if not os.path.isdir(save_directory):
185
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
186
+ return
187
+ out_vocab_file = os.path.join(
188
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
189
+ )
190
+
191
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
192
+ copyfile(self.vocab_file, out_vocab_file)
193
+
194
+ return (out_vocab_file,)
195
+
196
+
197
+ __all__ = ["BarthezTokenizerFast"]
.venv/lib/python3.11/site-packages/transformers/models/beit/configuration_beit.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """BEiT model configuration"""
16
+
17
+ import warnings
18
+ from collections import OrderedDict
19
+ from typing import Mapping
20
+
21
+ from packaging import version
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...onnx import OnnxConfig
25
+ from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
26
+
27
+
28
+ class BeitConfig(BackboneConfigMixin, PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT
31
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
+ defaults will yield a similar configuration to that of the BEiT
33
+ [microsoft/beit-base-patch16-224-pt22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k) architecture.
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 8192):
37
+ Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
38
+ pre-training.
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimensionality of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
53
+ The dropout ratio for the attention probabilities.
54
+ initializer_range (`float`, *optional*, defaults to 0.02):
55
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
56
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
57
+ The epsilon used by the layer normalization layers.
58
+ image_size (`int`, *optional*, defaults to 224):
59
+ The size (resolution) of each image.
60
+ patch_size (`int`, *optional*, defaults to 16):
61
+ The size (resolution) of each patch.
62
+ num_channels (`int`, *optional*, defaults to 3):
63
+ The number of input channels.
64
+ use_mask_token (`bool`, *optional*, defaults to `False`):
65
+ Whether to use a mask token for masked image modeling.
66
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
67
+ Whether to use BERT-style absolute position embeddings.
68
+ use_relative_position_bias (`bool`, *optional*, defaults to `False`):
69
+ Whether to use T5-style relative position embeddings in the self-attention layers.
70
+ use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
71
+ Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
72
+ layer_scale_init_value (`float`, *optional*, defaults to 0.1):
73
+ Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
74
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
75
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
76
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
77
+ Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
78
+ CLS token, before applying the classification head.
79
+ pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
80
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
81
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
82
+ Whether to use an auxiliary head during training.
83
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
84
+ Weight of the cross-entropy loss of the auxiliary head.
85
+ auxiliary_channels (`int`, *optional*, defaults to 256):
86
+ Number of channels to use in the auxiliary head.
87
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
88
+ Number of convolutional layers to use in the auxiliary head.
89
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
90
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
91
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
92
+ The index that is ignored by the loss function of the semantic segmentation model.
93
+ out_features (`List[str]`, *optional*):
94
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
95
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
96
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
97
+ same order as defined in the `stage_names` attribute.
98
+ out_indices (`List[int]`, *optional*):
99
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
100
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
101
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
102
+ same order as defined in the `stage_names` attribute.
103
+ add_fpn (`bool`, *optional*, defaults to `False`):
104
+ Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`].
105
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
106
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
107
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
108
+ seq_len, hidden_size)`. Only relevant for [`BeitBackbone`].
109
+
110
+ Example:
111
+
112
+ ```python
113
+ >>> from transformers import BeitConfig, BeitModel
114
+
115
+ >>> # Initializing a BEiT beit-base-patch16-224-pt22k style configuration
116
+ >>> configuration = BeitConfig()
117
+
118
+ >>> # Initializing a model (with random weights) from the beit-base-patch16-224-pt22k style configuration
119
+ >>> model = BeitModel(configuration)
120
+
121
+ >>> # Accessing the model configuration
122
+ >>> configuration = model.config
123
+ ```"""
124
+
125
+ model_type = "beit"
126
+
127
+ def __init__(
128
+ self,
129
+ vocab_size=8192,
130
+ hidden_size=768,
131
+ num_hidden_layers=12,
132
+ num_attention_heads=12,
133
+ intermediate_size=3072,
134
+ hidden_act="gelu",
135
+ hidden_dropout_prob=0.0,
136
+ attention_probs_dropout_prob=0.0,
137
+ initializer_range=0.02,
138
+ layer_norm_eps=1e-12,
139
+ image_size=224,
140
+ patch_size=16,
141
+ num_channels=3,
142
+ use_mask_token=False,
143
+ use_absolute_position_embeddings=False,
144
+ use_relative_position_bias=False,
145
+ use_shared_relative_position_bias=False,
146
+ layer_scale_init_value=0.1,
147
+ drop_path_rate=0.1,
148
+ use_mean_pooling=True,
149
+ pool_scales=[1, 2, 3, 6],
150
+ use_auxiliary_head=True,
151
+ auxiliary_loss_weight=0.4,
152
+ auxiliary_channels=256,
153
+ auxiliary_num_convs=1,
154
+ auxiliary_concat_input=False,
155
+ semantic_loss_ignore_index=255,
156
+ out_features=None,
157
+ out_indices=None,
158
+ add_fpn=False,
159
+ reshape_hidden_states=True,
160
+ **kwargs,
161
+ ):
162
+ super().__init__(**kwargs)
163
+
164
+ self.vocab_size = vocab_size
165
+ self.hidden_size = hidden_size
166
+ self.num_hidden_layers = num_hidden_layers
167
+ self.num_attention_heads = num_attention_heads
168
+ self.intermediate_size = intermediate_size
169
+ self.hidden_act = hidden_act
170
+ self.hidden_dropout_prob = hidden_dropout_prob
171
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
172
+ self.initializer_range = initializer_range
173
+ self.layer_norm_eps = layer_norm_eps
174
+
175
+ self.image_size = image_size
176
+ self.patch_size = patch_size
177
+ self.num_channels = num_channels
178
+ self.use_mask_token = use_mask_token
179
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
180
+ self.use_relative_position_bias = use_relative_position_bias
181
+ self.use_shared_relative_position_bias = use_shared_relative_position_bias
182
+ self.layer_scale_init_value = layer_scale_init_value
183
+ self.drop_path_rate = drop_path_rate
184
+ self.use_mean_pooling = use_mean_pooling
185
+ # decode head attributes (semantic segmentation)
186
+ self.pool_scales = pool_scales
187
+ # auxiliary head attributes (semantic segmentation)
188
+ self.use_auxiliary_head = use_auxiliary_head
189
+ self.auxiliary_loss_weight = auxiliary_loss_weight
190
+ self.auxiliary_channels = auxiliary_channels
191
+ self.auxiliary_num_convs = auxiliary_num_convs
192
+ self.auxiliary_concat_input = auxiliary_concat_input
193
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
194
+
195
+ # handle backwards compatibility
196
+ if "segmentation_indices" in kwargs:
197
+ warnings.warn(
198
+ "The `segmentation_indices` argument is deprecated and will be removed in a future version, use `out_indices` instead.",
199
+ FutureWarning,
200
+ )
201
+ out_indices = kwargs.pop("segmentation_indices")
202
+
203
+ # backbone attributes
204
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)]
205
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
206
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
207
+ )
208
+ self.add_fpn = add_fpn
209
+ self.reshape_hidden_states = reshape_hidden_states
210
+
211
+
212
+ # Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
213
+ class BeitOnnxConfig(OnnxConfig):
214
+ torch_onnx_minimum_version = version.parse("1.11")
215
+
216
+ @property
217
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
218
+ return OrderedDict(
219
+ [
220
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
221
+ ]
222
+ )
223
+
224
+ @property
225
+ def atol_for_validation(self) -> float:
226
+ return 1e-4
227
+
228
+
229
+ __all__ = ["BeitConfig", "BeitOnnxConfig"]
.venv/lib/python3.11/site-packages/transformers/models/beit/feature_extraction_beit.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for BEiT."""
16
+
17
+ import warnings
18
+
19
+ from ...utils import logging
20
+ from .image_processing_beit import BeitImageProcessor
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class BeitFeatureExtractor(BeitImageProcessor):
27
+ def __init__(self, *args, **kwargs) -> None:
28
+ warnings.warn(
29
+ "The class BeitFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
30
+ " use BeitImageProcessor instead.",
31
+ FutureWarning,
32
+ )
33
+ super().__init__(*args, **kwargs)
34
+
35
+
36
+ __all__ = ["BeitFeatureExtractor"]
.venv/lib/python3.11/site-packages/transformers/models/beit/image_processing_beit.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Beit."""
16
+
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import INIT_SERVICE_KWARGS, BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import resize, to_channel_dimension_format
23
+ from ...image_utils import (
24
+ IMAGENET_STANDARD_MEAN,
25
+ IMAGENET_STANDARD_STD,
26
+ ChannelDimension,
27
+ ImageInput,
28
+ PILImageResampling,
29
+ infer_channel_dimension_format,
30
+ is_scaled_image,
31
+ make_list_of_images,
32
+ to_numpy_array,
33
+ valid_images,
34
+ validate_preprocess_arguments,
35
+ )
36
+ from ...utils import (
37
+ TensorType,
38
+ filter_out_non_signature_kwargs,
39
+ is_torch_available,
40
+ is_torch_tensor,
41
+ is_vision_available,
42
+ logging,
43
+ )
44
+ from ...utils.deprecation import deprecate_kwarg
45
+
46
+
47
+ if is_vision_available():
48
+ import PIL
49
+
50
+ if is_torch_available():
51
+ import torch
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ class BeitImageProcessor(BaseImageProcessor):
58
+ r"""
59
+ Constructs a BEiT image processor.
60
+
61
+ Args:
62
+ do_resize (`bool`, *optional*, defaults to `True`):
63
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
64
+ `do_resize` parameter in the `preprocess` method.
65
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
66
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
67
+ method.
68
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
69
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
70
+ `preprocess` method.
71
+ do_center_crop (`bool`, *optional*, defaults to `True`):
72
+ Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
73
+ is padded with 0's and then center cropped. Can be overridden by the `do_center_crop` parameter in the
74
+ `preprocess` method.
75
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
76
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
77
+ Can be overridden by the `crop_size` parameter in the `preprocess` method.
78
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
79
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
80
+ `preprocess` method.
81
+ do_rescale (`bool`, *optional*, defaults to `True`):
82
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
83
+ parameter in the `preprocess` method.
84
+ do_normalize (`bool`, *optional*, defaults to `True`):
85
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
86
+ method.
87
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
88
+ The mean to use if normalizing the image. This is a float or list of floats of length of the number of
89
+ channels of the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
90
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
91
+ The standard deviation to use if normalizing the image. This is a float or list of floats of length of the
92
+ number of channels of the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
93
+ do_reduce_labels (`bool`, *optional*, defaults to `False`):
94
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
95
+ used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
96
+ background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
97
+ `preprocess` method.
98
+ """
99
+
100
+ model_input_names = ["pixel_values"]
101
+
102
+ @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
103
+ @filter_out_non_signature_kwargs(extra=INIT_SERVICE_KWARGS)
104
+ def __init__(
105
+ self,
106
+ do_resize: bool = True,
107
+ size: Dict[str, int] = None,
108
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
109
+ do_center_crop: bool = True,
110
+ crop_size: Dict[str, int] = None,
111
+ rescale_factor: Union[int, float] = 1 / 255,
112
+ do_rescale: bool = True,
113
+ do_normalize: bool = True,
114
+ image_mean: Optional[Union[float, List[float]]] = None,
115
+ image_std: Optional[Union[float, List[float]]] = None,
116
+ do_reduce_labels: bool = False,
117
+ **kwargs,
118
+ ) -> None:
119
+ super().__init__(**kwargs)
120
+ size = size if size is not None else {"height": 256, "width": 256}
121
+ size = get_size_dict(size)
122
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
123
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
124
+ self.do_resize = do_resize
125
+ self.size = size
126
+ self.resample = resample
127
+ self.do_center_crop = do_center_crop
128
+ self.crop_size = crop_size
129
+ self.do_rescale = do_rescale
130
+ self.rescale_factor = rescale_factor
131
+ self.do_normalize = do_normalize
132
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
133
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
134
+ self.do_reduce_labels = do_reduce_labels
135
+
136
+ @classmethod
137
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
138
+ """
139
+ Overrides the `from_dict` method from the base class to save support of deprecated `reduce_labels` in old configs
140
+ """
141
+ image_processor_dict = image_processor_dict.copy()
142
+ if "reduce_labels" in image_processor_dict:
143
+ image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
144
+ return super().from_dict(image_processor_dict, **kwargs)
145
+
146
+ def resize(
147
+ self,
148
+ image: np.ndarray,
149
+ size: Dict[str, int],
150
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
151
+ data_format: Optional[Union[str, ChannelDimension]] = None,
152
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
153
+ **kwargs,
154
+ ) -> np.ndarray:
155
+ """
156
+ Resize an image to (size["height"], size["width"]).
157
+
158
+ Args:
159
+ image (`np.ndarray`):
160
+ Image to resize.
161
+ size (`Dict[str, int]`):
162
+ Size of the output image.
163
+ resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
164
+ Resampling filter to use when resiizing the image.
165
+ data_format (`str` or `ChannelDimension`, *optional*):
166
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
167
+ input_data_format (`str` or `ChannelDimension`, *optional*):
168
+ The channel dimension format of the input image. If not provided, it will be inferred.
169
+ """
170
+ size = get_size_dict(size, default_to_square=True, param_name="size")
171
+ if "height" not in size or "width" not in size:
172
+ raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
173
+ return resize(
174
+ image,
175
+ size=(size["height"], size["width"]),
176
+ resample=resample,
177
+ data_format=data_format,
178
+ input_data_format=input_data_format,
179
+ **kwargs,
180
+ )
181
+
182
+ def reduce_label(self, label: ImageInput) -> np.ndarray:
183
+ label = to_numpy_array(label)
184
+ # Avoid using underflow conversion
185
+ label[label == 0] = 255
186
+ label = label - 1
187
+ label[label == 254] = 255
188
+ return label
189
+
190
+ def _preprocess(
191
+ self,
192
+ image: ImageInput,
193
+ do_reduce_labels: bool = None,
194
+ do_resize: bool = None,
195
+ size: Dict[str, int] = None,
196
+ resample: PILImageResampling = None,
197
+ do_center_crop: bool = None,
198
+ crop_size: Dict[str, int] = None,
199
+ do_rescale: bool = None,
200
+ rescale_factor: float = None,
201
+ do_normalize: bool = None,
202
+ image_mean: Optional[Union[float, List[float]]] = None,
203
+ image_std: Optional[Union[float, List[float]]] = None,
204
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
205
+ ):
206
+ if do_reduce_labels:
207
+ image = self.reduce_label(image)
208
+
209
+ if do_resize:
210
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
211
+
212
+ if do_center_crop:
213
+ image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
214
+
215
+ if do_rescale:
216
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
217
+
218
+ if do_normalize:
219
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
220
+
221
+ return image
222
+
223
+ def _preprocess_image(
224
+ self,
225
+ image: ImageInput,
226
+ do_resize: bool = None,
227
+ size: Dict[str, int] = None,
228
+ resample: PILImageResampling = None,
229
+ do_center_crop: bool = None,
230
+ crop_size: Dict[str, int] = None,
231
+ do_rescale: bool = None,
232
+ rescale_factor: float = None,
233
+ do_normalize: bool = None,
234
+ image_mean: Optional[Union[float, List[float]]] = None,
235
+ image_std: Optional[Union[float, List[float]]] = None,
236
+ data_format: Optional[Union[str, ChannelDimension]] = None,
237
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
238
+ ) -> np.ndarray:
239
+ """Preprocesses a single image."""
240
+ # All transformations expect numpy arrays.
241
+ image = to_numpy_array(image)
242
+ if do_rescale and is_scaled_image(image):
243
+ logger.warning_once(
244
+ "It looks like you are trying to rescale already rescaled images. If the input"
245
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
246
+ )
247
+ if input_data_format is None:
248
+ input_data_format = infer_channel_dimension_format(image)
249
+ image = self._preprocess(
250
+ image,
251
+ do_reduce_labels=False,
252
+ do_resize=do_resize,
253
+ size=size,
254
+ resample=resample,
255
+ do_center_crop=do_center_crop,
256
+ crop_size=crop_size,
257
+ do_rescale=do_rescale,
258
+ rescale_factor=rescale_factor,
259
+ do_normalize=do_normalize,
260
+ image_mean=image_mean,
261
+ image_std=image_std,
262
+ input_data_format=input_data_format,
263
+ )
264
+ if data_format is not None:
265
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
266
+ return image
267
+
268
+ def _preprocess_segmentation_map(
269
+ self,
270
+ segmentation_map: ImageInput,
271
+ do_resize: bool = None,
272
+ size: Dict[str, int] = None,
273
+ resample: PILImageResampling = None,
274
+ do_center_crop: bool = None,
275
+ crop_size: Dict[str, int] = None,
276
+ do_reduce_labels: bool = None,
277
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
278
+ ):
279
+ """Preprocesses a single segmentation map."""
280
+ # All transformations expect numpy arrays.
281
+ segmentation_map = to_numpy_array(segmentation_map)
282
+ # Add an axis to the segmentation maps for transformations.
283
+ if segmentation_map.ndim == 2:
284
+ segmentation_map = segmentation_map[None, ...]
285
+ added_dimension = True
286
+ input_data_format = ChannelDimension.FIRST
287
+ else:
288
+ added_dimension = False
289
+ if input_data_format is None:
290
+ input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
291
+ segmentation_map = self._preprocess(
292
+ image=segmentation_map,
293
+ do_reduce_labels=do_reduce_labels,
294
+ do_resize=do_resize,
295
+ resample=resample,
296
+ size=size,
297
+ do_center_crop=do_center_crop,
298
+ crop_size=crop_size,
299
+ do_normalize=False,
300
+ do_rescale=False,
301
+ input_data_format=ChannelDimension.FIRST,
302
+ )
303
+ # Remove extra axis if added
304
+ if added_dimension:
305
+ segmentation_map = np.squeeze(segmentation_map, axis=0)
306
+ segmentation_map = segmentation_map.astype(np.int64)
307
+ return segmentation_map
308
+
309
+ def __call__(self, images, segmentation_maps=None, **kwargs):
310
+ # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
311
+ # be passed in as positional arguments.
312
+ return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
313
+
314
+ @deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
315
+ @filter_out_non_signature_kwargs()
316
+ def preprocess(
317
+ self,
318
+ images: ImageInput,
319
+ segmentation_maps: Optional[ImageInput] = None,
320
+ do_resize: bool = None,
321
+ size: Dict[str, int] = None,
322
+ resample: PILImageResampling = None,
323
+ do_center_crop: bool = None,
324
+ crop_size: Dict[str, int] = None,
325
+ do_rescale: bool = None,
326
+ rescale_factor: float = None,
327
+ do_normalize: bool = None,
328
+ image_mean: Optional[Union[float, List[float]]] = None,
329
+ image_std: Optional[Union[float, List[float]]] = None,
330
+ do_reduce_labels: Optional[bool] = None,
331
+ return_tensors: Optional[Union[str, TensorType]] = None,
332
+ data_format: ChannelDimension = ChannelDimension.FIRST,
333
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
334
+ ) -> PIL.Image.Image:
335
+ """
336
+ Preprocess an image or batch of images.
337
+
338
+ Args:
339
+ images (`ImageInput`):
340
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
341
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
342
+ segmentation_maps (`ImageInput`, *optional*):
343
+ Segmentation maps to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
344
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
345
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
346
+ Whether to resize the image.
347
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
348
+ Size of the image after resizing.
349
+ resample (`int`, *optional*, defaults to `self.resample`):
350
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
351
+ has an effect if `do_resize` is set to `True`.
352
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
353
+ Whether to center crop the image.
354
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
355
+ Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
356
+ padded with zeros and then cropped.
357
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
358
+ Whether to rescale the image values to the [0, 1] range.
359
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
360
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
361
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
362
+ Whether to normalize the image.
363
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
364
+ Image mean.
365
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
366
+ Image standard deviation.
367
+ do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
368
+ Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
369
+ is used for background, and background itself is not included in all classes of a dataset (e.g.
370
+ ADE20k). The background label will be replaced by 255.
371
+ return_tensors (`str` or `TensorType`, *optional*):
372
+ The type of tensors to return. Can be one of:
373
+ - Unset: Return a list of `np.ndarray`.
374
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
375
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
376
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
377
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
378
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
379
+ The channel dimension format for the output image. Can be one of:
380
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
381
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
382
+ - Unset: Use the channel dimension format of the input image.
383
+ input_data_format (`ChannelDimension` or `str`, *optional*):
384
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
385
+ from the input image. Can be one of:
386
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
387
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
388
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
389
+ """
390
+ do_resize = do_resize if do_resize is not None else self.do_resize
391
+ size = size if size is not None else self.size
392
+ size = get_size_dict(size, default_to_square=True, param_name="size")
393
+ resample = resample if resample is not None else self.resample
394
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
395
+ crop_size = crop_size if crop_size is not None else self.crop_size
396
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
397
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
398
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
399
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
400
+ image_mean = image_mean if image_mean is not None else self.image_mean
401
+ image_std = image_std if image_std is not None else self.image_std
402
+ do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
403
+
404
+ images = make_list_of_images(images)
405
+
406
+ if segmentation_maps is not None:
407
+ segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
408
+
409
+ if segmentation_maps is not None and not valid_images(segmentation_maps):
410
+ raise ValueError(
411
+ "Invalid segmentation_maps type. Must be of type PIL.Image.Image, numpy.ndarray, "
412
+ "torch.Tensor, tf.Tensor or jax.ndarray."
413
+ )
414
+ if not valid_images(images):
415
+ raise ValueError(
416
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
417
+ "torch.Tensor, tf.Tensor or jax.ndarray."
418
+ )
419
+
420
+ validate_preprocess_arguments(
421
+ do_rescale=do_rescale,
422
+ rescale_factor=rescale_factor,
423
+ do_normalize=do_normalize,
424
+ image_mean=image_mean,
425
+ image_std=image_std,
426
+ do_center_crop=do_center_crop,
427
+ crop_size=crop_size,
428
+ do_resize=do_resize,
429
+ size=size,
430
+ resample=resample,
431
+ )
432
+
433
+ images = [
434
+ self._preprocess_image(
435
+ image=img,
436
+ do_resize=do_resize,
437
+ do_center_crop=do_center_crop,
438
+ do_rescale=do_rescale,
439
+ do_normalize=do_normalize,
440
+ resample=resample,
441
+ size=size,
442
+ rescale_factor=rescale_factor,
443
+ crop_size=crop_size,
444
+ image_mean=image_mean,
445
+ image_std=image_std,
446
+ data_format=data_format,
447
+ input_data_format=input_data_format,
448
+ )
449
+ for img in images
450
+ ]
451
+
452
+ data = {"pixel_values": images}
453
+
454
+ if segmentation_maps is not None:
455
+ segmentation_maps = [
456
+ self._preprocess_segmentation_map(
457
+ segmentation_map=segmentation_map,
458
+ do_reduce_labels=do_reduce_labels,
459
+ do_resize=do_resize,
460
+ resample=resample,
461
+ size=size,
462
+ do_center_crop=do_center_crop,
463
+ crop_size=crop_size,
464
+ )
465
+ for segmentation_map in segmentation_maps
466
+ ]
467
+ data["labels"] = segmentation_maps
468
+
469
+ return BatchFeature(data=data, tensor_type=return_tensors)
470
+
471
+ def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
472
+ """
473
+ Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
474
+
475
+ Args:
476
+ outputs ([`BeitForSemanticSegmentation`]):
477
+ Raw outputs of the model.
478
+ target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
479
+ List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
480
+ predictions will not be resized.
481
+
482
+ Returns:
483
+ semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
484
+ segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
485
+ specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
486
+ """
487
+ # TODO: add support for other frameworks
488
+ logits = outputs.logits
489
+
490
+ # Resize logits and compute semantic segmentation maps
491
+ if target_sizes is not None:
492
+ if len(logits) != len(target_sizes):
493
+ raise ValueError(
494
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
495
+ )
496
+
497
+ if is_torch_tensor(target_sizes):
498
+ target_sizes = target_sizes.numpy()
499
+
500
+ semantic_segmentation = []
501
+
502
+ for idx in range(len(logits)):
503
+ resized_logits = torch.nn.functional.interpolate(
504
+ logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
505
+ )
506
+ semantic_map = resized_logits[0].argmax(dim=0)
507
+ semantic_segmentation.append(semantic_map)
508
+ else:
509
+ semantic_segmentation = logits.argmax(dim=1)
510
+ semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
511
+
512
+ return semantic_segmentation
513
+
514
+
515
+ __all__ = ["BeitImageProcessor"]
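# --- Illustrative usage sketch (not part of the upstream file) ---
# Assumes the `microsoft/beit-base-finetuned-ade-640-640` checkpoint and a local
# "scene.jpg"; any BEiT semantic-segmentation checkpoint and image work the same
# way. PyTorch is required for `post_process_semantic_segmentation`.
import torch
from PIL import Image
from transformers import BeitForSemanticSegmentation, BeitImageProcessor

image_processor = BeitImageProcessor.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")

image = Image.open("scene.jpg")
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Upsample the logits back to the input resolution and take the per-pixel argmax.
segmentation = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]  # PIL gives (width, height); (height, width) is expected
)[0]
print(segmentation.shape)  # (height, width) tensor of semantic class ids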
.venv/lib/python3.11/site-packages/transformers/models/beit/modeling_flax_beit.py ADDED
@@ -0,0 +1,956 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Microsoft Research and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import Callable, List, Optional, Tuple
18
+
19
+ import flax
20
+ import flax.linen as nn
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
25
+ from flax.linen.attention import dot_product_attention_weights
26
+ from flax.traverse_util import flatten_dict, unflatten_dict
27
+
28
+ from ...modeling_flax_outputs import (
29
+ FlaxBaseModelOutput,
30
+ FlaxBaseModelOutputWithPooling,
31
+ FlaxMaskedLMOutput,
32
+ FlaxSequenceClassifierOutput,
33
+ )
34
+ from ...modeling_flax_utils import (
35
+ ACT2FN,
36
+ FlaxPreTrainedModel,
37
+ append_replace_return_docstrings,
38
+ overwrite_call_docstring,
39
+ )
40
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
41
+ from .configuration_beit import BeitConfig
42
+
43
+
44
+ @flax.struct.dataclass
45
+ class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
46
+ """
47
+ Class for outputs of [`FlaxBeitModel`].
48
+
49
+ Args:
50
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
51
+ Sequence of hidden-states at the output of the last layer of the model.
52
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
53
+ Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
54
+ *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
55
+ will be returned.
56
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
57
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
58
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
59
+ the initial embedding outputs.
60
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
61
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
62
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
63
+ the self-attention heads.
64
+ """
65
+
66
+
67
+ BEIT_START_DOCSTRING = r"""
68
+
69
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
70
+ library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
71
+
72
+ This model is also a
73
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
74
+ a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
75
+ behavior.
76
+
77
+ Finally, this model supports inherent JAX features such as:
78
+
79
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
80
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
81
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
82
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
83
+
84
+ Parameters:
85
+ config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
86
+ Initializing with a config file does not load the weights associated with the model, only the
87
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
88
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
89
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
90
+ `jax.numpy.bfloat16` (on TPUs).
91
+
92
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
93
+ specified all the computation will be performed with the given `dtype`.
94
+
95
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
96
+ parameters.**
97
+
98
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
99
+ [`~FlaxPreTrainedModel.to_bf16`].
100
+ """
101
+
102
+ BEIT_INPUTS_DOCSTRING = r"""
103
+ Args:
104
+ pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
105
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
106
+ [`AutoImageProcessor.__call__`] for details.
107
+
108
+ output_attentions (`bool`, *optional*):
109
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
110
+ tensors for more detail.
111
+ output_hidden_states (`bool`, *optional*):
112
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
113
+ more detail.
114
+ return_dict (`bool`, *optional*):
115
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
116
+ """
117
+
118
+
119
+ def relative_position_index_init(window_size: Tuple[int, int]) -> jnp.ndarray:
120
+ """
121
+ get pair-wise relative position index for each token inside the window
122
+ """
123
+ num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
124
+
125
+ coords_h = np.arange(window_size[0])
126
+ coords_w = np.arange(window_size[1])
127
+ coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww
128
+ coords_flatten = np.reshape(coords, (2, -1))
129
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
130
+ relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2
131
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
132
+ relative_coords[:, :, 1] += window_size[1] - 1
133
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
134
+
135
+ relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
136
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
137
+ relative_position_index[0, 0:] = num_relative_distance - 3
138
+ relative_position_index[0:, 0] = num_relative_distance - 2
139
+ relative_position_index[0, 0] = num_relative_distance - 1
140
+ return jnp.array(relative_position_index)
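# Illustrative sanity check (not part of the upstream file); assumes the helper above is in scope.
# For a 2x2 window the table covers the 4 patch tokens plus the CLS token, and the three largest
# values are reserved for the cls-to-token, token-to-cls and cls-to-cls slots.
idx = relative_position_index_init((2, 2))
print(idx.shape)       # (5, 5): Wh*Ww + 1 tokens on each axis
print(int(idx.max()))  # 11 == (2*2 - 1) * (2*2 - 1) + 3 - 1, the cls-to-cls slot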
141
+
142
+
143
+ def ones_with_scale(key, shape, scale, dtype=jnp.float32):
144
+ return jnp.ones(shape, dtype) * scale
145
+
146
+
147
+ class FlaxBeitDropPath(nn.Module):
148
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
149
+
150
+ rate: float
151
+
152
+ @nn.module.compact
153
+ def __call__(self, inputs, deterministic: Optional[bool] = True):
154
+ if self.rate == 0.0:
155
+ return inputs
156
+ keep_prob = 1.0 - self.rate
157
+ if deterministic:
158
+ return inputs
159
+ else:
160
+ shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
161
+ rng = self.make_rng("droppath")
162
+ random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
163
+ binary_tensor = jnp.floor(random_tensor)
164
+ output = inputs / keep_prob * binary_tensor
165
+ return output
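# Illustrative sketch (not part of the upstream file); assumes the class above is in scope.
# The module holds no parameters, is the identity when `deterministic=True`, and needs an RNG
# under the "droppath" collection when sampling (which is why the pretrained-model wrapper
# further below splits its dropout key into "dropout" and "droppath").
import jax
import jax.numpy as jnp

drop_path = FlaxBeitDropPath(rate=0.1)
x = jnp.ones((4, 16, 8))  # (batch, tokens, hidden)

variables = drop_path.init(jax.random.PRNGKey(0), x, deterministic=True)  # empty: no parameters
y_eval = drop_path.apply(variables, x, deterministic=True)                # identity
y_train = drop_path.apply(
    variables, x, deterministic=False, rngs={"droppath": jax.random.PRNGKey(1)}
)  # whole samples are zeroed with probability `rate`, survivors are scaled by 1 / keep_prob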
166
+
167
+
168
+ class FlaxBeitPatchEmbeddings(nn.Module):
169
+ config: BeitConfig
170
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
171
+
172
+ def setup(self):
173
+ self.num_channels = self.config.num_channels
174
+ image_size = self.config.image_size
175
+ patch_size = self.config.patch_size
176
+ num_patches = (image_size // patch_size) * (image_size // patch_size)
177
+ patch_shape = (image_size // patch_size, image_size // patch_size)
178
+ self.num_patches = num_patches
179
+ self.patch_shape = patch_shape
180
+ self.projection = nn.Conv(
181
+ self.config.hidden_size,
182
+ kernel_size=(patch_size, patch_size),
183
+ strides=(patch_size, patch_size),
184
+ padding="VALID",
185
+ dtype=self.dtype,
186
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
187
+ )
188
+
189
+ def __call__(self, pixel_values):
190
+ num_channels = pixel_values.shape[-1]
191
+ if num_channels != self.num_channels:
192
+ raise ValueError(
193
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
194
+ )
195
+ embeddings = self.projection(pixel_values)
196
+ batch_size, _, _, channels = embeddings.shape
197
+ return jnp.reshape(embeddings, (batch_size, -1, channels))
198
+
199
+
200
+ class FlaxBeitEmbeddings(nn.Module):
201
+ """Construct the CLS token, position and patch embeddings."""
202
+
203
+ config: BeitConfig
204
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
205
+
206
+ def setup(self):
207
+ self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
208
+ if self.config.use_mask_token:
209
+ self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
210
+ self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype)
211
+ num_patches = self.patch_embeddings.num_patches
212
+ if self.config.use_absolute_position_embeddings:
213
+ self.position_embeddings = self.param(
214
+ "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
215
+ )
216
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
217
+
218
+ def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True):
219
+ embeddings = self.patch_embeddings(pixel_values)
220
+ batch_size, seq_len, _ = embeddings.shape
221
+
222
+ cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
223
+ cls_tokens = cls_tokens.astype(embeddings.dtype)
224
+
225
+ if bool_masked_pos is not None:
226
+ mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size))
227
+ mask_tokens = mask_tokens.astype(embeddings.dtype)
228
+ # replace the masked visual tokens by mask_tokens
229
+ w = jnp.expand_dims(bool_masked_pos, axis=-1)
230
+ embeddings = embeddings * (1 - w) + mask_tokens * w
231
+
232
+ embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
233
+
234
+ if self.config.use_absolute_position_embeddings:
235
+ embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype)
236
+
237
+ embeddings = self.dropout(embeddings, deterministic=deterministic)
238
+ return embeddings
239
+
240
+
241
+ class FlaxBeitRelativePositionBias(nn.Module):
242
+ config: BeitConfig
243
+ window_size: Tuple[int, int]
244
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
245
+
246
+ def setup(self):
247
+ num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3
248
+ self.relative_position_bias_table = self.param(
249
+ "relative_position_bias_table",
250
+ nn.initializers.zeros,
251
+ (num_relative_distance, self.config.num_attention_heads),
252
+ ) # 2*Wh-1 * 2*Ww-1, nH
253
+ # cls to token & token to cls & cls to cls
254
+
255
+ self.relative_position_index = relative_position_index_init(self.window_size)
256
+
257
+ def __call__(self):
258
+ index = self.relative_position_index.reshape(-1)
259
+ shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1)
260
+ relative_position_bias = self.relative_position_bias_table[index].reshape(shape) # Wh*Ww,Wh*Ww,nH
261
+ return jnp.transpose(relative_position_bias, (2, 0, 1))
262
+
263
+
264
+ class FlaxBeitSelfAttention(nn.Module):
265
+ config: BeitConfig
266
+ window_size: Tuple[int, int]
267
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
268
+
269
+ def setup(self):
270
+ if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr(
271
+ self.config, "embedding_size"
272
+ ):
273
+ raise ValueError(
274
+ f"The hidden size {self.config.hidden_size} is not a multiple of the number of attention "
275
+ f"heads {self.config.num_attention_heads}."
276
+ )
277
+
278
+ self.query = nn.Dense(
279
+ self.config.hidden_size,
280
+ dtype=self.dtype,
281
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
282
+ )
283
+ self.key = nn.Dense(
284
+ self.config.hidden_size,
285
+ dtype=self.dtype,
286
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
287
+ use_bias=False,
288
+ )
289
+ self.value = nn.Dense(
290
+ self.config.hidden_size,
291
+ dtype=self.dtype,
292
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
293
+ )
294
+
295
+ self.relative_position_bias = (
296
+ FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype)
297
+ if self.window_size
298
+ else None
299
+ )
300
+
301
+ def __call__(
302
+ self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
303
+ ):
304
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
305
+
306
+ query_states = self.query(hidden_states).reshape(
307
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
308
+ )
309
+ value_states = self.value(hidden_states).reshape(
310
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
311
+ )
312
+ key_states = self.key(hidden_states).reshape(
313
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
314
+ )
315
+
316
+ dropout_rng = None
317
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
318
+ dropout_rng = self.make_rng("dropout")
319
+
320
+ attention_bias = jnp.array(0.0, dtype=self.dtype)
321
+ # Add relative position bias if present.
322
+ if self.relative_position_bias is not None:
323
+ attention_bias = jnp.expand_dims(self.relative_position_bias(), 0)
324
+ attention_bias = attention_bias.astype(query_states.dtype)
325
+
326
+ # Add shared relative position bias if provided.
327
+ if relative_position_bias is not None:
328
+ attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype)
329
+
330
+ attn_weights = dot_product_attention_weights(
331
+ query_states,
332
+ key_states,
333
+ bias=attention_bias,
334
+ dropout_rng=dropout_rng,
335
+ dropout_rate=self.config.attention_probs_dropout_prob,
336
+ broadcast_dropout=True,
337
+ deterministic=deterministic,
338
+ dtype=self.dtype,
339
+ precision=None,
340
+ )
341
+
342
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
343
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
344
+
345
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
346
+ return outputs
347
+
348
+
349
+ class FlaxBeitSelfOutput(nn.Module):
350
+ config: BeitConfig
351
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
+
353
+ def setup(self):
354
+ self.dense = nn.Dense(
355
+ self.config.hidden_size,
356
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
357
+ dtype=self.dtype,
358
+ )
359
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
360
+
361
+ def __call__(self, hidden_states, deterministic: bool = True):
362
+ hidden_states = self.dense(hidden_states)
363
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
364
+ return hidden_states
365
+
366
+
367
+ class FlaxBeitAttention(nn.Module):
368
+ config: BeitConfig
369
+ window_size: Tuple[int, int]
370
+ dtype: jnp.dtype = jnp.float32
371
+
372
+ def setup(self):
373
+ self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype)
374
+ self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype)
375
+
376
+ def __call__(
377
+ self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False
378
+ ):
379
+ attn_outputs = self.attention(
380
+ hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
381
+ )
382
+ attn_output = attn_outputs[0]
383
+ attn_output = self.output(attn_output, deterministic=deterministic)
384
+
385
+ outputs = (attn_output,)
386
+
387
+ if output_attentions:
388
+ outputs += (attn_outputs[1],)
389
+
390
+ return outputs
391
+
392
+
393
+ class FlaxBeitIntermediate(nn.Module):
394
+ config: BeitConfig
395
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
396
+
397
+ def setup(self):
398
+ self.dense = nn.Dense(
399
+ self.config.intermediate_size,
400
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
401
+ dtype=self.dtype,
402
+ )
403
+ self.activation = ACT2FN[self.config.hidden_act]
404
+
405
+ def __call__(self, hidden_states):
406
+ hidden_states = self.dense(hidden_states)
407
+ hidden_states = self.activation(hidden_states)
408
+
409
+ return hidden_states
410
+
411
+
412
+ class FlaxBeitOutput(nn.Module):
413
+ config: BeitConfig
414
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
415
+
416
+ def setup(self):
417
+ self.dense = nn.Dense(
418
+ self.config.hidden_size,
419
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
420
+ dtype=self.dtype,
421
+ )
422
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
423
+
424
+ def __call__(self, hidden_states, deterministic: bool = True):
425
+ hidden_states = self.dense(hidden_states)
426
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
427
+
428
+ return hidden_states
429
+
430
+
431
+ class FlaxBeitLayer(nn.Module):
432
+ config: BeitConfig
433
+ window_size: Tuple[int, int]
434
+ drop_path_rate: float
435
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
436
+
437
+ def setup(self):
438
+ self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype)
439
+ self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
440
+ self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
441
+ self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
442
+ self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
443
+ self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
444
+
445
+ self.init_values = self.config.layer_scale_init_value
446
+ if self.init_values > 0:
447
+ self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values)
448
+ self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values)
449
+ else:
450
+ self.lambda_1 = None
451
+ self.lambda_2 = None
452
+
453
+ def __call__(
454
+ self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
455
+ ):
456
+ self_attention_outputs = self.attention(
457
+ self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
458
+ relative_position_bias,
459
+ deterministic=deterministic,
460
+ output_attentions=output_attentions,
461
+ )
462
+ attention_output = self_attention_outputs[0]
463
+
464
+ # apply lambda_1 if present
465
+ if self.lambda_1 is not None:
466
+ attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output
467
+
468
+ # first residual connection
469
+ hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states
470
+
471
+ # in BEiT, layernorm is also applied after self-attention
472
+ layer_output = self.layernorm_after(hidden_states)
473
+
474
+ layer_output = self.intermediate(layer_output)
475
+ layer_output = self.output(layer_output, deterministic=deterministic)
476
+
477
+ # apply lambda_2 if present
478
+ if self.lambda_2 is not None:
479
+ layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output
480
+
481
+ # second residual connection
482
+ layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states
483
+
484
+ outputs = (layer_output,)
485
+
486
+ if output_attentions:
487
+ outputs += (self_attention_outputs[1],)
488
+
489
+ return outputs
490
+
491
+
492
+ class FlaxBeitLayerCollection(nn.Module):
493
+ config: BeitConfig
494
+ window_size: Tuple[int, int]
495
+ drop_path_rates: List[float]
496
+ relative_position_bias: Callable[[], jnp.ndarray]
497
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
498
+
499
+ def setup(self):
500
+ self.layers = [
501
+ FlaxBeitLayer(
502
+ self.config,
503
+ window_size=self.window_size if self.config.use_relative_position_bias else None,
504
+ drop_path_rate=self.drop_path_rates[i],
505
+ name=str(i),
506
+ dtype=self.dtype,
507
+ )
508
+ for i in range(self.config.num_hidden_layers)
509
+ ]
510
+
511
+ def __call__(
512
+ self,
513
+ hidden_states,
514
+ deterministic: bool = True,
515
+ output_attentions: bool = False,
516
+ output_hidden_states: bool = False,
517
+ return_dict: bool = True,
518
+ ):
519
+ all_attentions = () if output_attentions else None
520
+ all_hidden_states = () if output_hidden_states else None
521
+
522
+ for i, layer in enumerate(self.layers):
523
+ if output_hidden_states:
524
+ all_hidden_states += (hidden_states,)
525
+ relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None
526
+ layer_outputs = layer(
527
+ hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
528
+ )
529
+
530
+ hidden_states = layer_outputs[0]
531
+
532
+ if output_attentions:
533
+ all_attentions += (layer_outputs[1],)
534
+
535
+ if output_hidden_states:
536
+ all_hidden_states += (hidden_states,)
537
+
538
+ outputs = (hidden_states,)
539
+ if not return_dict:
540
+ return tuple(v for v in outputs if v is not None)
541
+
542
+ return FlaxBaseModelOutput(
543
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
544
+ )
545
+
546
+
547
+ class FlaxBeitEncoder(nn.Module):
548
+ config: BeitConfig
549
+ window_size: Tuple[int, int]
550
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
551
+
552
+ def setup(self):
553
+ if self.config.use_shared_relative_position_bias:
554
+ self.relative_position_bias = FlaxBeitRelativePositionBias(
555
+ config=self.config, window_size=self.window_size, dtype=self.dtype
556
+ )
557
+
558
+ # stochastic depth decay rule
559
+ drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
560
+ self.layer = FlaxBeitLayerCollection(
561
+ self.config,
562
+ window_size=self.window_size,
563
+ drop_path_rates=drop_path_rates,
564
+ relative_position_bias=self.relative_position_bias
565
+ if self.config.use_shared_relative_position_bias
566
+ else None,
567
+ dtype=self.dtype,
568
+ )
569
+
570
+ def __call__(
571
+ self,
572
+ hidden_states,
573
+ deterministic: bool = True,
574
+ output_attentions: bool = False,
575
+ output_hidden_states: bool = False,
576
+ return_dict: bool = True,
577
+ ):
578
+ return self.layer(
579
+ hidden_states,
580
+ deterministic=deterministic,
581
+ output_attentions=output_attentions,
582
+ output_hidden_states=output_hidden_states,
583
+ return_dict=return_dict,
584
+ )
585
+
586
+
587
+ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
588
+ """
589
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
590
+ models.
591
+ """
592
+
593
+ config_class = BeitConfig
594
+ base_model_prefix = "beit"
595
+ main_input_name = "pixel_values"
596
+ module_class: nn.Module = None
597
+
598
+ def __init__(
599
+ self,
600
+ config: BeitConfig,
601
+ input_shape=None,
602
+ seed: int = 0,
603
+ dtype: jnp.dtype = jnp.float32,
604
+ _do_init: bool = True,
605
+ **kwargs,
606
+ ):
607
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
608
+ if input_shape is None:
609
+ input_shape = (1, config.image_size, config.image_size, config.num_channels)
610
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
611
+
612
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
613
+ # init input tensors
614
+ pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
615
+
616
+ params_rng, dropout_rng = jax.random.split(rng)
617
+ dropout_rng, droppath_rng = jax.random.split(dropout_rng)
618
+ rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
619
+
620
+ random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
621
+
622
+ if params is not None:
623
+ random_params = flatten_dict(unfreeze(random_params))
624
+ params = flatten_dict(unfreeze(params))
625
+ for missing_key in self._missing_keys:
626
+ params[missing_key] = random_params[missing_key]
627
+ self._missing_keys = set()
628
+ return freeze(unflatten_dict(params))
629
+ else:
630
+ return random_params
631
+
632
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
633
+ def __call__(
634
+ self,
635
+ pixel_values,
636
+ bool_masked_pos=None,
637
+ params: dict = None,
638
+ dropout_rng: jax.random.PRNGKey = None,
639
+ train: bool = False,
640
+ output_attentions: Optional[bool] = None,
641
+ output_hidden_states: Optional[bool] = None,
642
+ return_dict: Optional[bool] = None,
643
+ ):
644
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
645
+ output_hidden_states = (
646
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
647
+ )
648
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
649
+
650
+ pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
651
+ # Handle any PRNG if needed
652
+ rngs = {}
653
+ if dropout_rng is not None:
654
+ dropout_rng, droppath_rng = jax.random.split(dropout_rng)
655
+ rngs["dropout"] = dropout_rng
656
+ rngs["droppath"] = droppath_rng
657
+
658
+ return self.module.apply(
659
+ {"params": params or self.params},
660
+ jnp.array(pixel_values, dtype=jnp.float32),
661
+ bool_masked_pos,
662
+ not train,
663
+ output_attentions,
664
+ output_hidden_states,
665
+ return_dict,
666
+ rngs=rngs,
667
+ )
668
+
669
+
670
+ class FlaxBeitPooler(nn.Module):
671
+ config: BeitConfig
672
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
673
+
674
+ def setup(self):
675
+ if self.config.use_mean_pooling:
676
+ self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
677
+
678
+ def __call__(self, hidden_states):
679
+ if self.config.use_mean_pooling:
680
+ # Mean pool the final hidden states of the patch tokens
681
+ patch_tokens = hidden_states[:, 1:, :]
682
+ pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))
683
+ else:
684
+ # Pool by simply taking the final hidden state of the [CLS] token
685
+ pooled_output = hidden_states[:, 0]
686
+
687
+ return pooled_output
688
+
689
+
690
+ class FlaxBeitModule(nn.Module):
691
+ config: BeitConfig
692
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
693
+ add_pooling_layer: bool = True
694
+
695
+ def setup(self):
696
+ self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype)
697
+ self.encoder = FlaxBeitEncoder(
698
+ self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype
699
+ )
700
+ if not self.config.use_mean_pooling:
701
+ self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
702
+ self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None
703
+
704
+ def __call__(
705
+ self,
706
+ pixel_values,
707
+ bool_masked_pos=None,
708
+ deterministic: bool = True,
709
+ output_attentions: bool = False,
710
+ output_hidden_states: bool = False,
711
+ return_dict: bool = True,
712
+ ):
713
+ hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic)
714
+
715
+ outputs = self.encoder(
716
+ hidden_states,
717
+ deterministic=deterministic,
718
+ output_attentions=output_attentions,
719
+ output_hidden_states=output_hidden_states,
720
+ return_dict=return_dict,
721
+ )
722
+ hidden_states = outputs[0]
723
+ if not self.config.use_mean_pooling:
724
+ hidden_states = self.layernorm(hidden_states)
725
+ pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
726
+
727
+ if not return_dict:
728
+ # if pooled is None, don't return it
729
+ if pooled is None:
730
+ return (hidden_states,) + outputs[1:]
731
+ return (hidden_states, pooled) + outputs[1:]
732
+
733
+ return FlaxBeitModelOutputWithPooling(
734
+ last_hidden_state=hidden_states,
735
+ pooler_output=pooled,
736
+ hidden_states=outputs.hidden_states,
737
+ attentions=outputs.attentions,
738
+ )
739
+
740
+
741
+ @add_start_docstrings(
742
+ "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
743
+ BEIT_START_DOCSTRING,
744
+ )
745
+ class FlaxBeitModel(FlaxBeitPreTrainedModel):
746
+ module_class = FlaxBeitModule
747
+
748
+
749
+ FLAX_BEIT_MODEL_DOCSTRING = """
750
+ Returns:
751
+
752
+ Examples:
753
+
754
+ ```python
755
+ >>> from transformers import AutoImageProcessor, FlaxBeitModel
756
+ >>> from PIL import Image
757
+ >>> import requests
758
+
759
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
760
+ >>> image = Image.open(requests.get(url, stream=True).raw)
761
+
762
+ >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
763
+ >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
764
+
765
+ >>> inputs = image_processor(images=image, return_tensors="np")
766
+ >>> outputs = model(**inputs)
767
+ >>> last_hidden_states = outputs.last_hidden_state
768
+ ```
769
+ """
770
+
771
+ overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
772
+ append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)
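# Illustrative sketch (not part of the upstream file): random, non-pretrained initialization,
# complementing the `from_pretrained` example in the docstring above. Shapes assume the default
# `BeitConfig` (224x224 images, 16x16 patches, hidden size 768).
import jax.numpy as jnp
from transformers import BeitConfig, FlaxBeitModel

config = BeitConfig()
model = FlaxBeitModel(config, seed=0)  # randomly initialized weights
pixel_values = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 197, 768): 196 patch tokens + 1 CLS token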
773
+
774
+
775
+ class FlaxBeitForMaskedImageModelingModule(nn.Module):
776
+ config: BeitConfig
777
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
778
+
779
+ def setup(self):
780
+ self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype)
781
+
782
+ # Classifier head
783
+ self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
784
+ self.lm_head = nn.Dense(
785
+ self.config.vocab_size,
786
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
787
+ dtype=self.dtype,
788
+ )
789
+
790
+ def __call__(
791
+ self,
792
+ pixel_values=None,
793
+ bool_masked_pos=None,
794
+ deterministic: bool = True,
795
+ output_attentions=None,
796
+ output_hidden_states=None,
797
+ return_dict=None,
798
+ ):
799
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
800
+
801
+ outputs = self.beit(
802
+ pixel_values,
803
+ bool_masked_pos,
804
+ deterministic=deterministic,
805
+ output_attentions=output_attentions,
806
+ output_hidden_states=output_hidden_states,
807
+ return_dict=return_dict,
808
+ )
809
+
810
+ sequence_output = outputs[0]
811
+ sequence_output = self.layernorm(sequence_output)
812
+ prediction_scores = self.lm_head(sequence_output[:, 1:])
813
+
814
+ if not return_dict:
815
+ output = (prediction_scores,) + outputs[2:]
816
+ return output
817
+
818
+ return FlaxMaskedLMOutput(
819
+ logits=prediction_scores,
820
+ hidden_states=outputs.hidden_states,
821
+ attentions=outputs.attentions,
822
+ )
823
+
824
+
825
+ @add_start_docstrings(
826
+ "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
827
+ BEIT_START_DOCSTRING,
828
+ )
829
+ class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel):
830
+ module_class = FlaxBeitForMaskedImageModelingModule
831
+
832
+
833
+ FLAX_BEIT_MLM_DOCSTRING = """
834
+ bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`):
835
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
836
+
837
+ Returns:
838
+
839
+ Examples:
840
+
841
+ ```python
842
+ >>> from transformers import AutoImageProcessor, FlaxBeitForMaskedImageModeling
843
+ >>> from PIL import Image
844
+ >>> import requests
845
+
846
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
847
+ >>> image = Image.open(requests.get(url, stream=True).raw)
848
+
849
+ >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
850
+ >>> model = FlaxBeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
851
+
852
+ >>> inputs = image_processor(images=image, return_tensors="np")
853
+ >>> outputs = model(**inputs)
854
+ >>> logits = outputs.logits
855
+ ```
856
+ """
857
+
858
+ overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING)
859
+ append_replace_return_docstrings(
860
+ FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig
861
+ )
862
+
863
+
864
+ class FlaxBeitForImageClassificationModule(nn.Module):
865
+ config: BeitConfig
866
+ dtype: jnp.dtype = jnp.float32
867
+
868
+ def setup(self):
869
+ self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True)
870
+ self.classifier = nn.Dense(
871
+ self.config.num_labels,
872
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
873
+ dtype=self.dtype,
874
+ )
875
+
876
+ def __call__(
877
+ self,
878
+ pixel_values=None,
879
+ bool_masked_pos=None,
880
+ deterministic: bool = True,
881
+ output_attentions=None,
882
+ output_hidden_states=None,
883
+ return_dict=None,
884
+ ):
885
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
886
+
887
+ outputs = self.beit(
888
+ pixel_values,
889
+ deterministic=deterministic,
890
+ output_attentions=output_attentions,
891
+ output_hidden_states=output_hidden_states,
892
+ return_dict=return_dict,
893
+ )
894
+
895
+ pooled_output = outputs[1]
896
+ logits = self.classifier(pooled_output)
897
+
898
+ if not return_dict:
899
+ output = (logits,) + outputs[2:]
900
+ return output
901
+
902
+ return FlaxSequenceClassifierOutput(
903
+ logits=logits,
904
+ hidden_states=outputs.hidden_states,
905
+ attentions=outputs.attentions,
906
+ )
907
+
908
+
909
+ @add_start_docstrings(
910
+ """
911
+ Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
912
+ hidden states of the patch tokens) e.g. for ImageNet.
913
+ """,
914
+ BEIT_START_DOCSTRING,
915
+ )
916
+ class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel):
917
+ module_class = FlaxBeitForImageClassificationModule
918
+
919
+
920
+ FLAX_BEIT_CLASSIF_DOCSTRING = """
921
+ Returns:
922
+
923
+ Example:
924
+
925
+ ```python
926
+ >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification
927
+ >>> from PIL import Image
928
+ >>> import requests
929
+
930
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
931
+ >>> image = Image.open(requests.get(url, stream=True).raw)
932
+
933
+ >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
934
+ >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
935
+
936
+ >>> inputs = image_processor(images=image, return_tensors="np")
937
+ >>> outputs = model(**inputs)
938
+ >>> logits = outputs.logits
939
+ >>> # model predicts one of the 1000 ImageNet classes
940
+ >>> predicted_class_idx = logits.argmax(-1).item()
941
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
942
+ ```
943
+ """
944
+
945
+ overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING)
946
+ append_replace_return_docstrings(
947
+ FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig
948
+ )
949
+
950
+
951
+ __all__ = [
952
+ "FlaxBeitForImageClassification",
953
+ "FlaxBeitForMaskedImageModeling",
954
+ "FlaxBeitModel",
955
+ "FlaxBeitPreTrainedModel",
956
+ ]
.venv/lib/python3.11/site-packages/transformers/models/code_llama/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_code_llama import *
22
+ from .tokenization_code_llama_fast import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (786 Bytes).
 
.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama.cpython-311.pyc ADDED
Binary file (22.5 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/code_llama/__pycache__/tokenization_code_llama_fast.cpython-311.pyc ADDED
Binary file (18 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama.py ADDED
@@ -0,0 +1,452 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 MetaAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Tokenization classes for Code LLaMA."""
18
+
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+
25
+ from ...convert_slow_tokenizer import import_protobuf
26
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
27
+ from ...utils import logging, requires_backends
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
33
+
34
+ SPIECE_UNDERLINE = "▁"
35
+
36
+ B_INST, E_INST = "[INST]", "[/INST]"
37
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
38
+
39
+ # fmt: off
40
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
41
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
42
+ that your responses are socially unbiased and positive in nature.
43
+
44
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
45
+ correct. If you don't know the answer to a question, please don't share false information."""
46
+ # fmt: on
47
+
48
+
49
+ class CodeLlamaTokenizer(PreTrainedTokenizer):
50
+ """
51
+ Construct a CodeLlama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as
52
+ there is no padding token in the original model.
53
+
54
+ The default configuration matches that of
55
+ [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
56
+ which supports prompt infilling.
57
+
58
+ Args:
59
+ vocab_file (`str`):
60
+ Path to the vocabulary file.
61
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
62
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
63
+ token instead.
64
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
65
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
66
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
67
+ The end of sequence token.
68
+
69
+ <Tip>
70
+
71
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
72
+ The token used is the `sep_token`.
73
+
74
+ </Tip>
75
+
76
+ prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
77
+ Prefix token used for infilling.
78
+ middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
79
+ Middle token used for infilling.
80
+ suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
81
+ Suffix token used for infilling.
82
+ eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
83
+ End of text token used for infilling.
84
+ fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
85
+ The token used to split the input between the prefix and suffix.
86
+ suffix_first (`bool`, *optional*, defaults to `False`):
87
+ Whether the input prompt and suffix should be formatted with the suffix first.
88
+ sp_model_kwargs (`dict`, *optional*):
89
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
90
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
91
+ to set:
92
+
93
+ - `enable_sampling`: Enable subword regularization.
94
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
95
+
96
+ - `nbest_size = {0,1}`: No sampling is performed.
97
+ - `nbest_size > 1`: samples from the nbest_size results.
98
+ - `nbest_size < 0`: assuming that nbest_size is infinite, samples from all hypotheses (lattice)
99
+ using forward-filtering-and-backward-sampling algorithm.
100
+
101
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
102
+ BPE-dropout.
103
+ add_bos_token (`bool`, *optional*, defaults to `True`):
104
+ Whether to add a beginning of sequence token at the start of sequences.
105
+ add_eos_token (`bool`, *optional*, defaults to `False`):
106
+ Whether to add an end of sequence token at the end of sequences.
107
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
108
+ Whether or not to clean up the tokenization spaces.
109
+ additional_special_tokens (`List[str]`, *optional*):
110
+ Additional special tokens used by the tokenizer.
111
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
112
+ Whether or not the default system prompt for Llama should be used.
113
+ """
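# Illustrative infilling sketch (not part of the upstream file); the checkpoint name is an
# assumption and any Code Llama repository shipping a `tokenizer.model` behaves the same.
# A <FILL_ME> marker splits the prompt into prefix and suffix, which `tokenize` re-emits as
# `▁<PRE> prefix ▁<SUF> suffix ▁<MID>` (suffix first when `suffix_first=True`) so the model
# can generate the missing middle.
from transformers import CodeLlamaTokenizer

tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")

prompt = "def remove_non_ascii(s: str) -> str:\n    <FILL_ME>\n    return result"
ids = tokenizer(prompt)["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids)[:8])  # <s>, ▁<PRE>, then the prefix pieces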
114
+
115
+ vocab_files_names = VOCAB_FILES_NAMES
116
+ model_input_names = ["input_ids", "attention_mask"]
117
+
118
+ def __init__(
119
+ self,
120
+ vocab_file,
121
+ unk_token="<unk>",
122
+ bos_token="<s>",
123
+ eos_token="</s>",
124
+ prefix_token="▁<PRE>",
125
+ middle_token="▁<MID>",
126
+ suffix_token="▁<SUF>",
127
+ eot_token="▁<EOT>",
128
+ fill_token="<FILL_ME>",
129
+ suffix_first=False,
130
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
131
+ add_bos_token=True,
132
+ add_eos_token=False,
133
+ clean_up_tokenization_spaces=False,
134
+ additional_special_tokens=None,
135
+ use_default_system_prompt=False,
136
+ **kwargs,
137
+ ):
138
+ requires_backends(self, "protobuf")
139
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
140
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
141
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
142
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
143
+
144
+ self.use_default_system_prompt = use_default_system_prompt
145
+ # mark tokens special to skip them
146
+ additional_special_tokens = additional_special_tokens or []
147
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
148
+ additional_special_tokens += [token] if token is not None else []
149
+
150
+ self.vocab_file = vocab_file
151
+ self.add_bos_token = add_bos_token
152
+ self.add_eos_token = add_eos_token
153
+ self._prefix_token = prefix_token
154
+ self._middle_token = middle_token
155
+ self._suffix_token = suffix_token
156
+ self._eot_token = eot_token
157
+ self.fill_token = fill_token
158
+ self.suffix_first = suffix_first
159
+ self.sp_model = self.get_spm_processor()
160
+
161
+ super().__init__(
162
+ bos_token=bos_token,
163
+ eos_token=eos_token,
164
+ unk_token=unk_token,
165
+ add_bos_token=add_bos_token,
166
+ add_eos_token=add_eos_token,
167
+ prefix_token=prefix_token,
168
+ middle_token=middle_token,
169
+ suffix_token=suffix_token,
170
+ eot_token=eot_token,
171
+ fill_token=fill_token,
172
+ sp_model_kwargs=self.sp_model_kwargs,
173
+ suffix_first=suffix_first,
174
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
175
+ additional_special_tokens=additional_special_tokens,
176
+ use_default_system_prompt=use_default_system_prompt,
177
+ **kwargs,
178
+ )
179
+
180
+ @property
181
+ def unk_token_length(self):
182
+ return len(self.sp_model.encode(str(self.unk_token)))
183
+
184
+ def get_spm_processor(self):
185
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
186
+ with open(self.vocab_file, "rb") as f:
187
+ sp_model = f.read()
188
+ model_pb2 = import_protobuf()
189
+ model = model_pb2.ModelProto.FromString(sp_model)
190
+ normalizer_spec = model_pb2.NormalizerSpec()
191
+ normalizer_spec.add_dummy_prefix = False
192
+ model.normalizer_spec.MergeFrom(normalizer_spec)
193
+ sp_model = model.SerializeToString()
194
+ tokenizer.LoadFromSerializedProto(sp_model)
195
+ return tokenizer
196
+
197
+ @property
198
+ def prefix_token(self):
199
+ return self._prefix_token
200
+
201
+ @property
202
+ def prefix_id(self):
203
+ if self._prefix_token is None:
204
+ return None
205
+ return self.convert_tokens_to_ids(self.prefix_token)
206
+
207
+ @property
208
+ def middle_token(self):
209
+ return self._middle_token
210
+
211
+ @property
212
+ def middle_id(self):
213
+ if self._middle_token is None:
214
+ return None
215
+ return self.convert_tokens_to_ids(self.middle_token)
216
+
217
+ @property
218
+ def suffix_token(self):
219
+ return self._suffix_token
220
+
221
+ @property
222
+ def suffix_id(self):
223
+ if self._suffix_token is None:
224
+ return None
225
+ return self.convert_tokens_to_ids(self.suffix_token)
226
+
227
+ @property
228
+ def eot_token(self):
229
+ return self._eot_token
230
+
231
+ @property
232
+ def eot_id(self):
233
+ if self._eot_token is None:
234
+ return None
235
+ return self.convert_tokens_to_ids(self.eot_token)
236
+
237
+ @property
238
+ def vocab_size(self):
239
+ """Returns vocab size"""
240
+ return self.sp_model.get_piece_size()
241
+
242
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
243
+ def get_vocab(self):
244
+ """Returns vocab as a dict"""
245
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
246
+ vocab.update(self.added_tokens_encoder)
247
+ return vocab
248
+
249
+ def tokenize(self, prefix, suffix=None, suffix_first=False, **kwargs) -> List[str]:
250
+ # add a prefix space to `prefix`
251
+ if self.fill_token is not None and self.fill_token in prefix and suffix is None:
252
+ prefix, suffix = prefix.split(self.fill_token)
253
+
254
+ if len(prefix) > 0:
255
+ prefix = SPIECE_UNDERLINE + prefix.replace(SPIECE_UNDERLINE, " ")
256
+
257
+ if suffix is None or len(suffix) < 1:
258
+ tokens = super().tokenize(prefix, **kwargs)
259
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
260
+ tokens = tokens[1:]
261
+ return tokens
262
+
263
+ prefix_tokens = self._tokenize(prefix) # prefix has an extra `SPIECE_UNDERLINE`
264
+
265
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
266
+ raise ValueError(
267
+ "The input either includes a `prefix` and a `suffix` used for the infilling task,"
268
+ f" or can be split on the {self.fill_token} token, creating a suffix and prefix,"
269
+ " but the model does not support `infilling`."
270
+ )
271
+ suffix_tokens = self._tokenize(suffix) # make sure CodeLlama sp model does not mess up
272
+
273
+ suffix_first = suffix_first if suffix_first is not None else self.suffix_first
274
+ if suffix_first:
275
+ # format as " <PRE> <SUF>{suf} <MID> {pre}"
276
+ return [self.prefix_token, self.suffix_token] + suffix_tokens + [self.middle_token] + prefix_tokens
277
+ else:
278
+ # format as " <PRE> {pre} <SUF>{suf} <MID>"
279
+ return [self.prefix_token] + prefix_tokens + [self.suffix_token] + suffix_tokens + [self.middle_token]
280
+
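A minimal sketch of the token layout this `tokenize` method produces for infilling, assuming `sentencepiece` and `protobuf` are installed and using `codellama/CodeLlama-7b-hf` purely as an example checkpoint name:

```python
from transformers import CodeLlamaTokenizer

tok = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")  # example checkpoint

# Everything before <FILL_ME> becomes the prefix, everything after it the suffix.
pieces = tok.tokenize("def add(a, b):<FILL_ME>    return result")

# With suffix_first=False (the default) the layout is:
# ["▁<PRE>", *prefix_pieces, "▁<SUF>", *suffix_pieces, "▁<MID>"]
assert pieces[0] == "▁<PRE>" and pieces[-1] == "▁<MID>"
```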
281
+ def _tokenize(self, text, **kwargs):
282
+ """
283
+ Returns a tokenized string.
284
+
285
+ Since the `add_dummy_prefix` option is deactivated, the sentencepiece internals will always strip any leading
286
+ SPIECE_UNDERLINE. For example, `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type=str)` gives
287
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. This is why we always encode `f"{unk_token}text"` and then strip
288
+ the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`:
289
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type=str)[4:]`.
290
+ """
291
+ tokens = self.sp_model.encode(text, out_type=str)
292
+ if not text.startswith((SPIECE_UNDERLINE, " ")):
293
+ return tokens
294
+ # 1. Encode string + prefix ex: "<unk> Hey"
295
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
296
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
297
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
298
+
299
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
300
+ def _convert_token_to_id(self, token):
301
+ """Converts a token (str) in an id using the vocab."""
302
+ return self.sp_model.piece_to_id(token)
303
+
304
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
305
+ def _convert_id_to_token(self, index):
306
+ """Converts an index (integer) in a token (str) using the vocab."""
307
+ token = self.sp_model.IdToPiece(index)
308
+ return token
309
+
310
+ def convert_tokens_to_string(self, tokens):
311
+ """Converts a sequence of tokens (string) in a single string."""
312
+ # since we manually add the prefix space, we have to remove it when decoding
313
+ if tokens[0].startswith(SPIECE_UNDERLINE):
314
+ tokens[0] = tokens[0][1:]
315
+
316
+ current_sub_tokens = []
317
+ out_string = ""
318
+ for _, token in enumerate(tokens):
319
+ # make sure that special tokens are not decoded using sentencepiece model
320
+ if token in self.all_special_tokens:
321
+ out_string += self.sp_model.decode(current_sub_tokens) + token
322
+ current_sub_tokens = []
323
+ else:
324
+ current_sub_tokens.append(token)
325
+ out_string += self.sp_model.decode(current_sub_tokens)
326
+ return out_string
327
+
328
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
329
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
330
+ """
331
+ Save the vocabulary and special tokens file to a directory.
332
+
333
+ Args:
334
+ save_directory (`str`):
335
+ The directory in which to save the vocabulary.
336
+
337
+ Returns:
338
+ `Tuple(str)`: Paths to the files saved.
339
+ """
340
+ if not os.path.isdir(save_directory):
341
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
342
+ return
343
+ out_vocab_file = os.path.join(
344
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
345
+ )
346
+
347
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
348
+ copyfile(self.vocab_file, out_vocab_file)
349
+ elif not os.path.isfile(self.vocab_file):
350
+ with open(out_vocab_file, "wb") as fi:
351
+ content_spiece_model = self.sp_model.serialized_model_proto()
352
+ fi.write(content_spiece_model)
353
+
354
+ return (out_vocab_file,)
355
+
356
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
357
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
358
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
359
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
360
+
361
+ output = bos_token_id + token_ids_0 + eos_token_id
362
+
363
+ if token_ids_1 is not None:
364
+ output = output + bos_token_id + token_ids_1 + eos_token_id
365
+
366
+ return output
367
+
368
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
369
+ def get_special_tokens_mask(
370
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
371
+ ) -> List[int]:
372
+ """
373
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
374
+ special tokens using the tokenizer `prepare_for_model` method.
375
+
376
+ Args:
377
+ token_ids_0 (`List[int]`):
378
+ List of IDs.
379
+ token_ids_1 (`List[int]`, *optional*):
380
+ Optional second list of IDs for sequence pairs.
381
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
382
+ Whether or not the token list is already formatted with special tokens for the model.
383
+
384
+ Returns:
385
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
386
+ """
387
+ if already_has_special_tokens:
388
+ return super().get_special_tokens_mask(
389
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
390
+ )
391
+
392
+ bos_token_id = [1] if self.add_bos_token else []
393
+ eos_token_id = [1] if self.add_eos_token else []
394
+
395
+ if token_ids_1 is None:
396
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
397
+ return (
398
+ bos_token_id
399
+ + ([0] * len(token_ids_0))
400
+ + eos_token_id
401
+ + bos_token_id
402
+ + ([0] * len(token_ids_1))
403
+ + eos_token_id
404
+ )
405
+
406
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
407
+ def create_token_type_ids_from_sequences(
408
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
409
+ ) -> List[int]:
410
+ """
411
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. A sequence
412
+ pair mask has the following format:
413
+
414
+ ```
415
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
416
+ | first sequence | second sequence |
417
+ ```
418
+
419
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
420
+
421
+ Args:
422
+ token_ids_0 (`List[int]`):
423
+ List of ids.
424
+ token_ids_1 (`List[int]`, *optional*):
425
+ Optional second list of IDs for sequence pairs.
426
+
427
+ Returns:
428
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
429
+ """
430
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
431
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
432
+
433
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
434
+
435
+ if token_ids_1 is not None:
436
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
437
+
438
+ return output
439
+
440
+ def __getstate__(self):
441
+ state = self.__dict__.copy()
442
+ state["sp_model"] = None
443
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
444
+ return state
445
+
446
+ def __setstate__(self, d):
447
+ self.__dict__ = d
448
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
449
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
450
+
451
+
452
+ __all__ = ["CodeLlamaTokenizer"]
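The special-token helpers above are deterministic given `add_bos_token`/`add_eos_token`; a short sketch with hypothetical toy ids, reusing the `tok` instance from the earlier sketch (defaults: `add_bos_token=True`, `add_eos_token=False`):

```python
ids = tok.build_inputs_with_special_tokens([10, 11, 12])
# -> [tok.bos_token_id, 10, 11, 12]   (no EOS appended because add_eos_token=False)

mask = tok.get_special_tokens_mask([10, 11, 12])
# -> [1, 0, 0, 0]                     (the 1 marks the prepended BOS)

type_ids = tok.create_token_type_ids_from_sequences([10, 11], [20])
# -> [0, 0, 0, 1, 1]                  (the BOS is counted with each segment)
```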
.venv/lib/python3.11/site-packages/transformers/models/code_llama/tokenization_code_llama_fast.py ADDED
@@ -0,0 +1,381 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from shutil import copyfile
17
+ from typing import List, Optional, Tuple
18
+
19
+ from tokenizers import normalizers, processors
20
+
21
+ from ...tokenization_utils_fast import PreTrainedTokenizerFast
22
+ from ...utils import is_sentencepiece_available, logging
23
+ from ...utils.versions import require_version
24
+
25
+
26
+ require_version("tokenizers>=0.13.3")
27
+
28
+ if is_sentencepiece_available():
29
+ from .tokenization_code_llama import CodeLlamaTokenizer
30
+ else:
31
+ CodeLlamaTokenizer = None
32
+
33
+ logger = logging.get_logger(__name__)
34
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
35
+
36
+ SPIECE_UNDERLINE = "▁"
37
+
38
+
39
+ B_INST, E_INST = "[INST]", "[/INST]"
40
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
41
+
42
+ # fmt: off
43
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
44
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
45
+ that your responses are socially unbiased and positive in nature.
46
+
47
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
48
+ correct. If you don't know the answer to a question, please don't share false information."""
49
+ # fmt: on
50
+
51
+
52
+ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
53
+ """
54
+ Construct a Code Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
55
+
56
+ This uses notably ByteFallback and no normalization.
57
+
58
+ ```python
59
+ >>> from transformers import CodeLlamaTokenizerFast
60
+
61
+ >>> tokenizer = CodeLlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
62
+ >>> tokenizer.encode("Hello this is a test")
63
+ [1, 15043, 445, 338, 263, 1243]
64
+ ```
65
+
66
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
67
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
68
+ values of the first token and final token of an encoded sequence will not be correct). For more details, check out
69
+ the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
70
+
71
+
72
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
73
+ refer to this superclass for more information regarding those methods. The default configuration matches that of
74
+ [meta-llama/CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf/blob/main/tokenizer_config.json)
75
+ which supports prompt infilling.
76
+
77
+ Args:
78
+ vocab_file (`str`, *optional*):
79
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
80
+ contains the vocabulary necessary to instantiate a tokenizer.
81
+ tokenizer_file (`str`, *optional*):
82
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
83
+ contains everything needed to load the tokenizer.
84
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
85
+ Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
86
+ spaces.
87
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
88
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
89
+ token instead.
90
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
91
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
92
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
93
+ The end of sequence token.
94
+ prefix_token (`str`, *optional*, defaults to `"▁<PRE>"`):
95
+ Prefix token used for infilling.
96
+ middle_token (`str`, *optional*, defaults to `"▁<MID>"`):
97
+ Middle token used for infilling.
98
+ suffix_token (`str`, *optional*, defaults to `"▁<SUF>"`):
99
+ Suffix token used for infilling.
100
+ eot_token (`str`, *optional*, defaults to `"▁<EOT>"`):
101
+ End of text token used for infilling.
102
+ fill_token (`str`, *optional*, defaults to `"<FILL_ME>"`):
103
+ The token used to split the input between the prefix and suffix.
104
+ additional_special_tokens (`List[str]`, *optional*):
105
+ Additional special tokens used by the tokenizer.
106
+ add_bos_token (`bool`, *optional*, defaults to `True`):
107
+ Whether to add a beginning of sequence token at the start of sequences.
108
+ add_eos_token (`bool`, *optional*, defaults to `False`):
109
+ Whether to add an end of sequence token at the end of sequences.
110
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
111
+ Whether or not the default system prompt for Llama should be used.
112
+ """
113
+
114
+ vocab_files_names = VOCAB_FILES_NAMES
115
+ slow_tokenizer_class = CodeLlamaTokenizer
116
+ padding_side = "left"
117
+ model_input_names = ["input_ids", "attention_mask"]
118
+
119
+ def __init__(
120
+ self,
121
+ vocab_file=None,
122
+ tokenizer_file=None,
123
+ clean_up_tokenization_spaces=False,
124
+ unk_token="<unk>",
125
+ bos_token="<s>",
126
+ eos_token="</s>",
127
+ prefix_token="▁<PRE>",
128
+ middle_token="▁<MID>",
129
+ suffix_token="▁<SUF>",
130
+ eot_token="▁<EOT>",
131
+ fill_token="<FILL_ME>",
132
+ additional_special_tokens=None,
133
+ add_bos_token=True,
134
+ add_eos_token=False,
135
+ use_default_system_prompt=False,
136
+ **kwargs,
137
+ ):
138
+ # mark tokens special to skip them
139
+ additional_special_tokens = additional_special_tokens or []
140
+ for token in [prefix_token, middle_token, suffix_token, eot_token]:
141
+ additional_special_tokens += [token] if token is not None else []
142
+ self.use_default_system_prompt = use_default_system_prompt
143
+
144
+ super().__init__(
145
+ vocab_file=vocab_file,
146
+ tokenizer_file=tokenizer_file,
147
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
148
+ additional_special_tokens=additional_special_tokens,
149
+ unk_token=unk_token,
150
+ bos_token=bos_token,
151
+ eos_token=eos_token,
152
+ add_bos_token=add_bos_token,
153
+ add_eos_token=add_eos_token,
154
+ prefix_token=prefix_token,
155
+ middle_token=middle_token,
156
+ suffix_token=suffix_token,
157
+ eot_token=eot_token,
158
+ fill_token=fill_token,
159
+ use_default_system_prompt=use_default_system_prompt,
160
+ **kwargs,
161
+ )
162
+ self._add_bos_token = add_bos_token
163
+ self._add_eos_token = add_eos_token
164
+ self.update_post_processor()
165
+
166
+ self.vocab_file = vocab_file
167
+
168
+ self._prefix_token = prefix_token
169
+ self._middle_token = middle_token
170
+ self._suffix_token = suffix_token
171
+ self._eot_token = eot_token
172
+ self.fill_token = fill_token
173
+
174
+ @property
175
+ def can_save_slow_tokenizer(self) -> bool:
176
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
177
+
178
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
179
+ def update_post_processor(self):
180
+ """
181
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
182
+ """
183
+ bos = self.bos_token
184
+ bos_token_id = self.bos_token_id
185
+ if bos is None and self.add_bos_token:
186
+ raise ValueError("add_bos_token = True but bos_token = None")
187
+
188
+ eos = self.eos_token
189
+ eos_token_id = self.eos_token_id
190
+ if eos is None and self.add_eos_token:
191
+ raise ValueError("add_eos_token = True but eos_token = None")
192
+
193
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
194
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
195
+
196
+ special_tokens = []
197
+ if self.add_bos_token:
198
+ special_tokens.append((bos, bos_token_id))
199
+ if self.add_eos_token:
200
+ special_tokens.append((eos, eos_token_id))
201
+ self._tokenizer.post_processor = processors.TemplateProcessing(
202
+ single=single, pair=pair, special_tokens=special_tokens
203
+ )
204
+
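With the default settings (`add_bos_token=True`, `add_eos_token=False`, `bos_token="<s>"`), the two template strings assembled above evaluate to the following, i.e. a BOS is prepended to each segment and no EOS is appended:

```python
single = "<s>:0 $A:0"
pair = "<s>:0 $A:0 <s>:1 $B:1"
```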
205
+ @property
206
+ def prefix_token(self):
207
+ return self._prefix_token
208
+
209
+ @property
210
+ def prefix_id(self):
211
+ if self._prefix_token is None:
212
+ return None
213
+ return self.convert_tokens_to_ids(self.prefix_token)
214
+
215
+ @property
216
+ def middle_token(self):
217
+ return self._middle_token
218
+
219
+ @property
220
+ def middle_id(self):
221
+ if self._middle_token is None:
222
+ return None
223
+ return self.convert_tokens_to_ids(self.middle_token)
224
+
225
+ @property
226
+ def suffix_token(self):
227
+ return self._suffix_token
228
+
229
+ @property
230
+ def suffix_id(self):
231
+ if self._suffix_token is None:
232
+ return None
233
+ return self.convert_tokens_to_ids(self.suffix_token)
234
+
235
+ @property
236
+ def eot_id(self):
237
+ if self._eot_token is None:
238
+ return None
239
+ return self.convert_tokens_to_ids(self.eot_token)
240
+
241
+ @property
242
+ def eot_token(self):
243
+ return self._eot_token
244
+
245
+ @property
246
+ def add_eos_token(self):
247
+ return self._add_eos_token
248
+
249
+ @property
250
+ def add_bos_token(self):
251
+ return self._add_bos_token
252
+
253
+ @add_eos_token.setter
254
+ def add_eos_token(self, value):
255
+ self._add_eos_token = value
256
+ self.update_post_processor()
257
+
258
+ @add_bos_token.setter
259
+ def add_bos_token(self, value):
260
+ self._add_bos_token = value
261
+ self.update_post_processor()
262
+
263
+ def set_infilling_processor(self, reset, suffix_first=False, add_special_tokens=True):
264
+ """
265
+ Updates the normalizer and post-processor to make sure the prompt format for `infilling` is respected. The
266
+ infilling format is the following: if `suffix_first`
267
+ " <PRE> <SUF>{suf} <MID> {pre}"
268
+ else:
269
+ " <PRE> {pre} <SUF>{suf} <MID>"
270
+
271
+ If `reset` is set to `True`, the `normalizer` and `post_processor` are reset to their "normal" behaviour, which
272
+ is to add a prefix space for the normalizer, and add a `bos_token` to the input text for the `post_processor`.
273
+ """
274
+ if reset:
275
+ self._tokenizer.normalizer = normalizers.Sequence(
276
+ [
277
+ normalizers.Prepend(prepend="▁"),
278
+ normalizers.Replace(pattern=" ", content="▁"),
279
+ ]
280
+ )
281
+ self.update_post_processor()
282
+ return
283
+
284
+ self._tokenizer.normalizer = normalizers.Replace(pattern=" ", content="▁")
285
+ pair = [self.bos_token] if self.add_bos_token and add_special_tokens else []
286
+ special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
287
+ if suffix_first:
288
+ # format as " <PRE> <SUF>{suf} <MID> {pre}"
289
+ pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
290
+ special_tokens += [
291
+ (self.prefix_token, self.prefix_id),
292
+ (self.suffix_token, self.suffix_id),
293
+ (self.middle_token, self.middle_id),
294
+ ]
295
+ else:
296
+ # format as " <PRE> {pre} <SUF>{suf} <MID>"
297
+ pair += [self.prefix_token, "$A", self.suffix_token, "$B", self.middle_token]
298
+ special_tokens += [
299
+ (self.prefix_token, self.prefix_id),
300
+ (self.suffix_token, self.suffix_id),
301
+ (self.middle_token, self.middle_id),
302
+ ]
303
+
304
+ if self.add_eos_token and add_special_tokens:
305
+ pair += [self.eos_token]
306
+ special_tokens += [(self.eos_token, self.eos_token_id)]
307
+ self._tokenizer.post_processor = processors.TemplateProcessing(
308
+ single="$A", pair=pair, special_tokens=special_tokens
309
+ )
310
+
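For the non-reset path with the default special tokens, `add_special_tokens=True` and `suffix_first=False`, the temporary pair template assembled above works out to (a worked example, not an additional API):

```python
pair = ["<s>", "▁<PRE>", "$A", "▁<SUF>", "$B", "▁<MID>"]
# "$A" is filled with the prefix and "$B" with the suffix,
# reproducing the " <PRE> {pre} <SUF>{suf} <MID>" infilling format.
```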
311
+ def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
312
+ # hack to make sure the input is pre-processed here, outside of the Rust tokenizer
313
+ text_pair = kwargs.pop("suffix", text_pair)
314
+ if self.fill_token is not None and self.fill_token in text and text_pair is None:
315
+ text, text_pair = text.split(self.fill_token)
316
+
317
+ if text_pair is None or len(text_pair) < 1:
318
+ return super().encode_plus(text, text_pair, add_special_tokens=add_special_tokens, **kwargs)
319
+
320
+ if None in (self.prefix_id, self.middle_id, self.suffix_id):
321
+ raise ValueError(
322
+ "Then input includes a `prefix` and a `suffix` used for the infilling task,"
323
+ " the `prefix_id, middle_id, suffix_id` must all be initialized. Current"
324
+ f" values : {self.prefix_id, self.middle_id, self.suffix_id}"
325
+ )
326
+
327
+ self.set_infilling_processor(False, suffix_first=suffix_first, add_special_tokens=add_special_tokens)
328
+ tokens = super().encode_plus(" " + text, text_pair=text_pair, add_special_tokens=True, **kwargs)
329
+ self.set_infilling_processor(True)
330
+ return tokens
331
+
332
+ # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
333
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
334
+ if not self.can_save_slow_tokenizer:
335
+ raise ValueError(
336
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
337
+ "tokenizer."
338
+ )
339
+
340
+ if not os.path.isdir(save_directory):
341
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
342
+ return
343
+ out_vocab_file = os.path.join(
344
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
345
+ )
346
+
347
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
348
+ copyfile(self.vocab_file, out_vocab_file)
349
+
350
+ return (out_vocab_file,)
351
+
352
+ def build_inputs_with_special_tokens(
353
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
354
+ ) -> List[int]:
355
+ """
356
+ Build model inputs from a sequence or a pair of sequences by concatenating the token ids and adding the
357
+ special tokens configured on this tokenizer.
358
+
359
+ A sequence has the following format, where `X` represents the sequence:
360
+
361
+ - single sequence: `<s> X </s>`
362
+ - pair of sequences: `<s> A B </s>`
363
+
364
+ Pairs of sequences are not the expected use case, but they will be handled without a
365
+ separator.
366
+
367
+ Args:
368
+ token_ids_0 (`List[int]`):
369
+ List of IDs to which the special tokens will be added.
370
+ token_ids_1 (`List[int]`, *optional*):
371
+ Optional second list of IDs for sequence pairs.
372
+
373
+ Returns:
374
+ `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
375
+ """
376
+ if token_ids_1 is None:
377
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
378
+ return [self.bos_token_id] + token_ids_0 + token_ids_1 + [self.eos_token_id]
379
+
380
+
381
+ __all__ = ["CodeLlamaTokenizerFast"]
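A minimal end-to-end sketch of infilling with the fast tokenizer, again using `codellama/CodeLlama-7b-hf` only as an example checkpoint name (a `tokenizer.json` must be available for the fast path):

```python
from transformers import CodeLlamaTokenizerFast

fast_tok = CodeLlamaTokenizerFast.from_pretrained("codellama/CodeLlama-7b-hf")  # example checkpoint

# encode_plus splits on <FILL_ME>, installs the infilling post-processor, encodes, then restores it.
enc = fast_tok("def add(a, b):\n    return <FILL_ME>\n")
print(fast_tok.convert_ids_to_tokens(enc["input_ids"])[:2])  # ['<s>', '▁<PRE>'] with the defaults
```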
.venv/lib/python3.11/site-packages/transformers/models/convnext/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_convnext import *
22
+ from .feature_extraction_convnext import *
23
+ from .image_processing_convnext import *
24
+ from .modeling_convnext import *
25
+ from .modeling_tf_convnext import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
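A short sketch of what the `_LazyModule` shim above provides: submodules are only imported when one of their public symbols is first accessed.

```python
# Nothing under transformers.models.convnext is imported eagerly; this line triggers
# the lazy load of configuration_convnext.
from transformers.models.convnext import ConvNextConfig

config = ConvNextConfig()
print(config.model_type)  # "convnext"
```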
.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/configuration_convnext.cpython-311.pyc ADDED
Binary file (7.08 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/convnext/__pycache__/image_processing_convnext.cpython-311.pyc ADDED
Binary file (16.5 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/convnext/configuration_convnext.py ADDED
@@ -0,0 +1,142 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ConvNeXT model configuration"""
16
+
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from packaging import version
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...onnx import OnnxConfig
24
+ from ...utils import logging
25
+ from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
34
+ ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
35
+ with the defaults will yield a similar configuration to that of the ConvNeXT
36
+ [facebook/convnext-tiny-224](https://huggingface.co/facebook/convnext-tiny-224) architecture.
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ num_channels (`int`, *optional*, defaults to 3):
43
+ The number of input channels.
44
+ patch_size (`int`, *optional*, defaults to 4):
45
+ Patch size to use in the patch embedding layer.
46
+ num_stages (`int`, *optional*, defaults to 4):
47
+ The number of stages in the model.
48
+ hidden_sizes (`List[int]`, *optional*, defaults to [96, 192, 384, 768]):
49
+ Dimensionality (hidden size) at each stage.
50
+ depths (`List[int]`, *optional*, defaults to [3, 3, 9, 3]):
51
+ Depth (number of blocks) for each stage.
52
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
53
+ The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
54
+ `"selu"` and `"gelu_new"` are supported.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
58
+ The epsilon used by the layer normalization layers.
59
+ layer_scale_init_value (`float`, *optional*, defaults to 1e-6):
60
+ The initial value for the layer scale.
61
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
62
+ The drop rate for stochastic depth.
63
+ out_features (`List[str]`, *optional*):
64
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
65
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
66
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
67
+ same order as defined in the `stage_names` attribute.
68
+ out_indices (`List[int]`, *optional*):
69
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
70
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
71
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
72
+ same order as defined in the `stage_names` attribute.
73
+
74
+ Example:
75
+ ```python
76
+ >>> from transformers import ConvNextConfig, ConvNextModel
77
+
78
+ >>> # Initializing a ConvNext convnext-tiny-224 style configuration
79
+ >>> configuration = ConvNextConfig()
80
+
81
+ >>> # Initializing a model (with random weights) from the convnext-tiny-224 style configuration
82
+ >>> model = ConvNextModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "convnext"
89
+
90
+ def __init__(
91
+ self,
92
+ num_channels=3,
93
+ patch_size=4,
94
+ num_stages=4,
95
+ hidden_sizes=None,
96
+ depths=None,
97
+ hidden_act="gelu",
98
+ initializer_range=0.02,
99
+ layer_norm_eps=1e-12,
100
+ layer_scale_init_value=1e-6,
101
+ drop_path_rate=0.0,
102
+ image_size=224,
103
+ out_features=None,
104
+ out_indices=None,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+
109
+ self.num_channels = num_channels
110
+ self.patch_size = patch_size
111
+ self.num_stages = num_stages
112
+ self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
113
+ self.depths = [3, 3, 9, 3] if depths is None else depths
114
+ self.hidden_act = hidden_act
115
+ self.initializer_range = initializer_range
116
+ self.layer_norm_eps = layer_norm_eps
117
+ self.layer_scale_init_value = layer_scale_init_value
118
+ self.drop_path_rate = drop_path_rate
119
+ self.image_size = image_size
120
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
121
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
122
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
123
+ )
124
+
125
+
126
+ class ConvNextOnnxConfig(OnnxConfig):
127
+ torch_onnx_minimum_version = version.parse("1.11")
128
+
129
+ @property
130
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
131
+ return OrderedDict(
132
+ [
133
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
134
+ ]
135
+ )
136
+
137
+ @property
138
+ def atol_for_validation(self) -> float:
139
+ return 1e-5
140
+
141
+
142
+ __all__ = ["ConvNextConfig", "ConvNextOnnxConfig"]
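A small sketch of the backbone-related arguments documented above; `out_features`/`out_indices` are aligned against `stage_names` by the helper:

```python
from transformers import ConvNextConfig

config = ConvNextConfig(out_features=["stage2", "stage4"])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage4']
print(config.out_indices)   # indices into stage_names, expected [2, 4]
```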
.venv/lib/python3.11/site-packages/transformers/models/convnext/feature_extraction_convnext.py ADDED
@@ -0,0 +1,36 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Feature extractor class for ConvNeXT."""
16
+
17
+ import warnings
18
+
19
+ from ...utils import logging
20
+ from .image_processing_convnext import ConvNextImageProcessor
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class ConvNextFeatureExtractor(ConvNextImageProcessor):
27
+ def __init__(self, *args, **kwargs) -> None:
28
+ warnings.warn(
29
+ "The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers."
30
+ " Please use ConvNextImageProcessor instead.",
31
+ FutureWarning,
32
+ )
33
+ super().__init__(*args, **kwargs)
34
+
35
+
36
+ __all__ = ["ConvNextFeatureExtractor"]
.venv/lib/python3.11/site-packages/transformers/models/convnext/image_processing_convnext.py ADDED
@@ -0,0 +1,323 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for ConvNeXT."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import (
23
+ center_crop,
24
+ get_resize_output_image_size,
25
+ resize,
26
+ to_channel_dimension_format,
27
+ )
28
+ from ...image_utils import (
29
+ IMAGENET_STANDARD_MEAN,
30
+ IMAGENET_STANDARD_STD,
31
+ ChannelDimension,
32
+ ImageInput,
33
+ PILImageResampling,
34
+ infer_channel_dimension_format,
35
+ is_scaled_image,
36
+ make_list_of_images,
37
+ to_numpy_array,
38
+ valid_images,
39
+ validate_preprocess_arguments,
40
+ )
41
+ from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
42
+
43
+
44
+ if is_vision_available():
45
+ import PIL
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class ConvNextImageProcessor(BaseImageProcessor):
52
+ r"""
53
+ Constructs a ConvNeXT image processor.
54
+
55
+ Args:
56
+ do_resize (`bool`, *optional*, defaults to `True`):
57
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
58
+ by `do_resize` in the `preprocess` method.
59
+ size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
60
+ Resolution of the output image after `resize` is applied. If `size["shortest_edge"]` >= 384, the image is
61
+ resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the image will
62
+ be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
63
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
64
+ be overridden by `size` in the `preprocess` method.
65
+ crop_pct (`float`, *optional*, defaults to 224 / 256):
66
+ Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
67
+ overridden by `crop_pct` in the `preprocess` method.
68
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
69
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
70
+ do_rescale (`bool`, *optional*, defaults to `True`):
71
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
72
+ the `preprocess` method.
73
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
74
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
75
+ method.
76
+ do_normalize (`bool`, *optional*, defaults to `True`):
77
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
78
+ method.
79
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
80
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
81
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
82
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
83
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
84
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
85
+ """
86
+
87
+ model_input_names = ["pixel_values"]
88
+
89
+ def __init__(
90
+ self,
91
+ do_resize: bool = True,
92
+ size: Dict[str, int] = None,
93
+ crop_pct: float = None,
94
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
95
+ do_rescale: bool = True,
96
+ rescale_factor: Union[int, float] = 1 / 255,
97
+ do_normalize: bool = True,
98
+ image_mean: Optional[Union[float, List[float]]] = None,
99
+ image_std: Optional[Union[float, List[float]]] = None,
100
+ **kwargs,
101
+ ) -> None:
102
+ super().__init__(**kwargs)
103
+ size = size if size is not None else {"shortest_edge": 384}
104
+ size = get_size_dict(size, default_to_square=False)
105
+
106
+ self.do_resize = do_resize
107
+ self.size = size
108
+ # Default value set here for backwards compatibility where the value in config is None
109
+ self.crop_pct = crop_pct if crop_pct is not None else 224 / 256
110
+ self.resample = resample
111
+ self.do_rescale = do_rescale
112
+ self.rescale_factor = rescale_factor
113
+ self.do_normalize = do_normalize
114
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
115
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
116
+
117
+ def resize(
118
+ self,
119
+ image: np.ndarray,
120
+ size: Dict[str, int],
121
+ crop_pct: float,
122
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
123
+ data_format: Optional[Union[str, ChannelDimension]] = None,
124
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
125
+ **kwargs,
126
+ ) -> np.ndarray:
127
+ """
128
+ Resize an image.
129
+
130
+ Args:
131
+ image (`np.ndarray`):
132
+ Image to resize.
133
+ size (`Dict[str, int]`):
134
+ Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
135
+ `size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
136
+ Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
137
+ after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
138
+ crop_pct (`float`):
139
+ Percentage of the image to crop. Only has an effect if size < 384.
140
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
141
+ Resampling filter to use when resizing the image.
142
+ data_format (`str` or `ChannelDimension`, *optional*):
143
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
144
+ input_data_format (`ChannelDimension` or `str`, *optional*):
145
+ The channel dimension format of the input image. If not provided, it will be inferred from the input
146
+ image.
147
+ """
148
+ size = get_size_dict(size, default_to_square=False)
149
+ if "shortest_edge" not in size:
150
+ raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
151
+ shortest_edge = size["shortest_edge"]
152
+
153
+ if shortest_edge < 384:
154
+ # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
155
+ resize_shortest_edge = int(shortest_edge / crop_pct)
156
+ resize_size = get_resize_output_image_size(
157
+ image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
158
+ )
159
+ image = resize(
160
+ image=image,
161
+ size=resize_size,
162
+ resample=resample,
163
+ data_format=data_format,
164
+ input_data_format=input_data_format,
165
+ **kwargs,
166
+ )
167
+ # then crop to (shortest_edge, shortest_edge)
168
+ return center_crop(
169
+ image=image,
170
+ size=(shortest_edge, shortest_edge),
171
+ data_format=data_format,
172
+ input_data_format=input_data_format,
173
+ **kwargs,
174
+ )
175
+ else:
176
+ # warping (no cropping) when evaluated at 384 or larger
177
+ return resize(
178
+ image,
179
+ size=(shortest_edge, shortest_edge),
180
+ resample=resample,
181
+ data_format=data_format,
182
+ input_data_format=input_data_format,
183
+ **kwargs,
184
+ )
185
+
186
+ @filter_out_non_signature_kwargs()
187
+ def preprocess(
188
+ self,
189
+ images: ImageInput,
190
+ do_resize: bool = None,
191
+ size: Dict[str, int] = None,
192
+ crop_pct: float = None,
193
+ resample: PILImageResampling = None,
194
+ do_rescale: bool = None,
195
+ rescale_factor: float = None,
196
+ do_normalize: bool = None,
197
+ image_mean: Optional[Union[float, List[float]]] = None,
198
+ image_std: Optional[Union[float, List[float]]] = None,
199
+ return_tensors: Optional[Union[str, TensorType]] = None,
200
+ data_format: ChannelDimension = ChannelDimension.FIRST,
201
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
202
+ ) -> PIL.Image.Image:
203
+ """
204
+ Preprocess an image or batch of images.
205
+
206
+ Args:
207
+ images (`ImageInput`):
208
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
209
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
210
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
211
+ Whether to resize the image.
212
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
213
+ Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
214
+ is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
215
+ image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
216
+ `(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
217
+ crop_pct (`float`, *optional*, defaults to `self.crop_pct`):
218
+ Percentage of the image to crop if size < 384.
219
+ resample (`int`, *optional*, defaults to `self.resample`):
220
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only
221
+ has an effect if `do_resize` is set to `True`.
222
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
223
+ Whether to rescale the image values to the range [0, 1].
224
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
225
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
226
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
227
+ Whether to normalize the image.
228
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
229
+ Image mean.
230
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
231
+ Image standard deviation.
232
+ return_tensors (`str` or `TensorType`, *optional*):
233
+ The type of tensors to return. Can be one of:
234
+ - Unset: Return a list of `np.ndarray`.
235
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
236
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
237
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
238
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
239
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
240
+ The channel dimension format for the output image. Can be one of:
241
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
242
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
243
+ - Unset: Use the channel dimension format of the input image.
244
+ input_data_format (`ChannelDimension` or `str`, *optional*):
245
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
246
+ from the input image. Can be one of:
247
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
248
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
249
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
250
+ """
251
+ do_resize = do_resize if do_resize is not None else self.do_resize
252
+ crop_pct = crop_pct if crop_pct is not None else self.crop_pct
253
+ resample = resample if resample is not None else self.resample
254
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
255
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
256
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
257
+ image_mean = image_mean if image_mean is not None else self.image_mean
258
+ image_std = image_std if image_std is not None else self.image_std
259
+
260
+ size = size if size is not None else self.size
261
+ size = get_size_dict(size, default_to_square=False)
262
+
263
+ images = make_list_of_images(images)
264
+
265
+ if not valid_images(images):
266
+ raise ValueError(
267
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
268
+ "torch.Tensor, tf.Tensor or jax.ndarray."
269
+ )
270
+
271
+ validate_preprocess_arguments(
272
+ do_rescale=do_rescale,
273
+ rescale_factor=rescale_factor,
274
+ do_normalize=do_normalize,
275
+ image_mean=image_mean,
276
+ image_std=image_std,
277
+ do_resize=do_resize,
278
+ size=size,
279
+ resample=resample,
280
+ )
281
+
282
+ # All transformations expect numpy arrays.
283
+ images = [to_numpy_array(image) for image in images]
284
+
285
+ if do_rescale and is_scaled_image(images[0]):
286
+ logger.warning_once(
287
+ "It looks like you are trying to rescale already rescaled images. If the input"
288
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
289
+ )
290
+
291
+ if input_data_format is None:
292
+ # We assume that all images have the same channel dimension format.
293
+ input_data_format = infer_channel_dimension_format(images[0])
294
+
295
+ if do_resize:
296
+ images = [
297
+ self.resize(
298
+ image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
299
+ )
300
+ for image in images
301
+ ]
302
+
303
+ if do_rescale:
304
+ images = [
305
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
306
+ for image in images
307
+ ]
308
+
309
+ if do_normalize:
310
+ images = [
311
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
312
+ for image in images
313
+ ]
314
+
315
+ images = [
316
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
317
+ ]
318
+
319
+ data = {"pixel_values": images}
320
+ return BatchFeature(data=data, tensor_type=return_tensors)
321
+
322
+
323
+ __all__ = ["ConvNextImageProcessor"]
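A minimal usage sketch of the processor defined above on a dummy NumPy image (defaults: `shortest_edge=384`, rescale by 1/255, ImageNet mean/std):

```python
import numpy as np
from transformers import ConvNextImageProcessor

processor = ConvNextImageProcessor()
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # (H, W, C) dummy image

batch = processor(image, return_tensors="np")
print(batch["pixel_values"].shape)  # (1, 3, 384, 384) with the default size
```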
.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_convnext.py ADDED
@@ -0,0 +1,551 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ConvNext model."""
16
+
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_outputs import (
26
+ BackboneOutput,
27
+ BaseModelOutputWithNoAttention,
28
+ BaseModelOutputWithPoolingAndNoAttention,
29
+ ImageClassifierOutputWithNoAttention,
30
+ )
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...utils import (
33
+ add_code_sample_docstrings,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from ...utils.backbone_utils import BackboneMixin
40
+ from .configuration_convnext import ConvNextConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ # General docstring
46
+ _CONFIG_FOR_DOC = "ConvNextConfig"
47
+
48
+ # Base docstring
49
+ _CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
50
+ _EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]
51
+
52
+ # Image classification docstring
53
+ _IMAGE_CLASS_CHECKPOINT = "facebook/convnext-tiny-224"
54
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
55
+
56
+
57
+ # Copied from transformers.models.beit.modeling_beit.drop_path
58
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
59
+ """
60
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
61
+
62
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
63
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
64
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
65
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
66
+ argument.
67
+ """
68
+ if drop_prob == 0.0 or not training:
69
+ return input
70
+ keep_prob = 1 - drop_prob
71
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
72
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
73
+ random_tensor.floor_() # binarize
74
+ output = input.div(keep_prob) * random_tensor
75
+ return output
76
+
77
+
78
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
79
+ class ConvNextDropPath(nn.Module):
80
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
81
+
82
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
83
+ super().__init__()
84
+ self.drop_prob = drop_prob
85
+
86
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
87
+ return drop_path(hidden_states, self.drop_prob, self.training)
88
+
89
+ def extra_repr(self) -> str:
90
+ return "p={}".format(self.drop_prob)
91
+
92
+
93
+ class ConvNextLayerNorm(nn.Module):
94
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
95
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
96
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
97
+ """
98
+
99
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
100
+ super().__init__()
101
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
102
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
103
+ self.eps = eps
104
+ self.data_format = data_format
105
+ if self.data_format not in ["channels_last", "channels_first"]:
106
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
107
+ self.normalized_shape = (normalized_shape,)
108
+
109
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
110
+ if self.data_format == "channels_last":
111
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
112
+ elif self.data_format == "channels_first":
113
+ input_dtype = x.dtype
114
+ x = x.float()
115
+ u = x.mean(1, keepdim=True)
116
+ s = (x - u).pow(2).mean(1, keepdim=True)
117
+ x = (x - u) / torch.sqrt(s + self.eps)
118
+ x = x.to(dtype=input_dtype)
119
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
120
+ return x
121
+
122
+
123
+ class ConvNextEmbeddings(nn.Module):
124
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
125
+ found in src/transformers/models/swin/modeling_swin.py.
126
+ """
127
+
128
+ def __init__(self, config):
129
+ super().__init__()
130
+ self.patch_embeddings = nn.Conv2d(
131
+ config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
132
+ )
133
+ self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
134
+ self.num_channels = config.num_channels
135
+
136
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
137
+ num_channels = pixel_values.shape[1]
138
+ if num_channels != self.num_channels:
139
+ raise ValueError(
140
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
141
+ )
142
+ embeddings = self.patch_embeddings(pixel_values)
143
+ embeddings = self.layernorm(embeddings)
144
+ return embeddings
145
+
146
+
147
+ class ConvNextLayer(nn.Module):
148
+ """This corresponds to the `Block` class in the original implementation.
149
+
150
+ There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in (N, C,
151
+ H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
152
+
153
+ The authors used (2) as they find it slightly faster in PyTorch.
154
+
155
+ Args:
156
+ config ([`ConvNextConfig`]): Model configuration class.
157
+ dim (`int`): Number of input channels.
158
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
159
+ """
160
+
161
+ def __init__(self, config, dim, drop_path=0):
162
+ super().__init__()
163
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
164
+ self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
165
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
166
+ self.act = ACT2FN[config.hidden_act]
167
+ self.pwconv2 = nn.Linear(4 * dim, dim)
168
+ self.layer_scale_parameter = (
169
+ nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
170
+ if config.layer_scale_init_value > 0
171
+ else None
172
+ )
173
+ self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
174
+
175
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
176
+ input = hidden_states
177
+ x = self.dwconv(hidden_states)
178
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
179
+ x = self.layernorm(x)
180
+ x = self.pwconv1(x)
181
+ x = self.act(x)
182
+ x = self.pwconv2(x)
183
+ if self.layer_scale_parameter is not None:
184
+ x = self.layer_scale_parameter * x
185
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
186
+
187
+ x = input + self.drop_path(x)
188
+ return x
189
+
190
+
191
+ class ConvNextStage(nn.Module):
192
+ """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
193
+
194
+ Args:
195
+ config ([`ConvNextConfig`]): Model configuration class.
196
+ in_channels (`int`): Number of input channels.
197
+ out_channels (`int`): Number of output channels.
198
+ depth (`int`): Number of residual blocks.
199
+ drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
200
+ """
201
+
202
+ def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
203
+ super().__init__()
204
+
205
+ if in_channels != out_channels or stride > 1:
206
+ self.downsampling_layer = nn.Sequential(
207
+ ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
208
+ nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
209
+ )
210
+ else:
211
+ self.downsampling_layer = nn.Identity()
212
+ drop_path_rates = drop_path_rates or [0.0] * depth
213
+ self.layers = nn.Sequential(
214
+ *[ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
215
+ )
216
+
217
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
218
+ hidden_states = self.downsampling_layer(hidden_states)
219
+ hidden_states = self.layers(hidden_states)
220
+ return hidden_states
221
+
222
+
223
+ class ConvNextEncoder(nn.Module):
224
+ def __init__(self, config):
225
+ super().__init__()
226
+ self.stages = nn.ModuleList()
227
+ drop_path_rates = [
228
+ x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
229
+ ]
230
+ prev_chs = config.hidden_sizes[0]
231
+ for i in range(config.num_stages):
232
+ out_chs = config.hidden_sizes[i]
233
+ stage = ConvNextStage(
234
+ config,
235
+ in_channels=prev_chs,
236
+ out_channels=out_chs,
237
+ stride=2 if i > 0 else 1,
238
+ depth=config.depths[i],
239
+ drop_path_rates=drop_path_rates[i],
240
+ )
241
+ self.stages.append(stage)
242
+ prev_chs = out_chs
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.FloatTensor,
247
+ output_hidden_states: Optional[bool] = False,
248
+ return_dict: Optional[bool] = True,
249
+ ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
250
+ all_hidden_states = () if output_hidden_states else None
251
+
252
+ for i, layer_module in enumerate(self.stages):
253
+ if output_hidden_states:
254
+ all_hidden_states = all_hidden_states + (hidden_states,)
255
+
256
+ hidden_states = layer_module(hidden_states)
257
+
258
+ if output_hidden_states:
259
+ all_hidden_states = all_hidden_states + (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
263
+
264
+ return BaseModelOutputWithNoAttention(
265
+ last_hidden_state=hidden_states,
266
+ hidden_states=all_hidden_states,
267
+ )
268
+
269
+
270
+ class ConvNextPreTrainedModel(PreTrainedModel):
271
+ """
272
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
273
+ models.
274
+ """
275
+
276
+ config_class = ConvNextConfig
277
+ base_model_prefix = "convnext"
278
+ main_input_name = "pixel_values"
279
+ _no_split_modules = ["ConvNextLayer"]
280
+
281
+ def _init_weights(self, module):
282
+ """Initialize the weights"""
283
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
284
+ # Slightly different from the TF version which uses truncated_normal for initialization
285
+ # cf https://github.com/pytorch/pytorch/pull/5617
286
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
287
+ if module.bias is not None:
288
+ module.bias.data.zero_()
289
+ elif isinstance(module, nn.LayerNorm):
290
+ module.bias.data.zero_()
291
+ module.weight.data.fill_(1.0)
292
+
293
+
294
+ CONVNEXT_START_DOCSTRING = r"""
295
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
296
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
297
+ behavior.
298
+
299
+ Parameters:
300
+ config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
301
+ Initializing with a config file does not load the weights associated with the model, only the
302
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
303
+ """
304
+
305
+ CONVNEXT_INPUTS_DOCSTRING = r"""
306
+ Args:
307
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
308
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
309
+ [`ConvNextImageProcessor.__call__`] for details.
310
+
311
+ output_hidden_states (`bool`, *optional*):
312
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
313
+ more detail.
314
+ return_dict (`bool`, *optional*):
315
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
316
+ """
317
+
318
+
319
+ @add_start_docstrings(
320
+ "The bare ConvNext model outputting raw features without any specific head on top.",
321
+ CONVNEXT_START_DOCSTRING,
322
+ )
323
+ class ConvNextModel(ConvNextPreTrainedModel):
324
+ def __init__(self, config):
325
+ super().__init__(config)
326
+ self.config = config
327
+
328
+ self.embeddings = ConvNextEmbeddings(config)
329
+ self.encoder = ConvNextEncoder(config)
330
+
331
+ # final layernorm layer
332
+ self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
333
+
334
+ # Initialize weights and apply final processing
335
+ self.post_init()
336
+
337
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
338
+ @add_code_sample_docstrings(
339
+ checkpoint=_CHECKPOINT_FOR_DOC,
340
+ output_type=BaseModelOutputWithPoolingAndNoAttention,
341
+ config_class=_CONFIG_FOR_DOC,
342
+ modality="vision",
343
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
344
+ )
345
+ def forward(
346
+ self,
347
+ pixel_values: torch.FloatTensor = None,
348
+ output_hidden_states: Optional[bool] = None,
349
+ return_dict: Optional[bool] = None,
350
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
351
+ output_hidden_states = (
352
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
353
+ )
354
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
355
+
356
+ if pixel_values is None:
357
+ raise ValueError("You have to specify pixel_values")
358
+
359
+ embedding_output = self.embeddings(pixel_values)
360
+
361
+ encoder_outputs = self.encoder(
362
+ embedding_output,
363
+ output_hidden_states=output_hidden_states,
364
+ return_dict=return_dict,
365
+ )
366
+
367
+ last_hidden_state = encoder_outputs[0]
368
+
369
+ # global average pooling, (N, C, H, W) -> (N, C)
370
+ pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
371
+
372
+ if not return_dict:
373
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
374
+
375
+ return BaseModelOutputWithPoolingAndNoAttention(
376
+ last_hidden_state=last_hidden_state,
377
+ pooler_output=pooled_output,
378
+ hidden_states=encoder_outputs.hidden_states,
379
+ )
380
+
381
+
382
+ @add_start_docstrings(
383
+ """
384
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
385
+ ImageNet.
386
+ """,
387
+ CONVNEXT_START_DOCSTRING,
388
+ )
389
+ class ConvNextForImageClassification(ConvNextPreTrainedModel):
390
+ def __init__(self, config):
391
+ super().__init__(config)
392
+
393
+ self.num_labels = config.num_labels
394
+ self.convnext = ConvNextModel(config)
395
+
396
+ # Classifier head
397
+ self.classifier = (
398
+ nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
399
+ )
400
+
401
+ # Initialize weights and apply final processing
402
+ self.post_init()
403
+
404
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
405
+ @add_code_sample_docstrings(
406
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
407
+ output_type=ImageClassifierOutputWithNoAttention,
408
+ config_class=_CONFIG_FOR_DOC,
409
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
410
+ )
411
+ def forward(
412
+ self,
413
+ pixel_values: torch.FloatTensor = None,
414
+ labels: Optional[torch.LongTensor] = None,
415
+ output_hidden_states: Optional[bool] = None,
416
+ return_dict: Optional[bool] = None,
417
+ ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
418
+ r"""
419
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
420
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
421
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
422
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
423
+ """
424
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
425
+
426
+ outputs = self.convnext(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
427
+
428
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
429
+
430
+ logits = self.classifier(pooled_output)
431
+
432
+ loss = None
433
+ if labels is not None:
434
+ if self.config.problem_type is None:
435
+ if self.num_labels == 1:
436
+ self.config.problem_type = "regression"
437
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
438
+ self.config.problem_type = "single_label_classification"
439
+ else:
440
+ self.config.problem_type = "multi_label_classification"
441
+
442
+ if self.config.problem_type == "regression":
443
+ loss_fct = MSELoss()
444
+ if self.num_labels == 1:
445
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
446
+ else:
447
+ loss = loss_fct(logits, labels)
448
+ elif self.config.problem_type == "single_label_classification":
449
+ loss_fct = CrossEntropyLoss()
450
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
451
+ elif self.config.problem_type == "multi_label_classification":
452
+ loss_fct = BCEWithLogitsLoss()
453
+ loss = loss_fct(logits, labels)
454
+ if not return_dict:
455
+ output = (logits,) + outputs[2:]
456
+ return ((loss,) + output) if loss is not None else output
457
+
458
+ return ImageClassifierOutputWithNoAttention(
459
+ loss=loss,
460
+ logits=logits,
461
+ hidden_states=outputs.hidden_states,
462
+ )
463
+
464
+
465
+ @add_start_docstrings(
466
+ """
467
+ ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
468
+ """,
469
+ CONVNEXT_START_DOCSTRING,
470
+ )
471
+ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
472
+ def __init__(self, config):
473
+ super().__init__(config)
474
+ super()._init_backbone(config)
475
+
476
+ self.embeddings = ConvNextEmbeddings(config)
477
+ self.encoder = ConvNextEncoder(config)
478
+ self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
479
+
480
+ # Add layer norms to hidden states of out_features
481
+ hidden_states_norms = {}
482
+ for stage, num_channels in zip(self._out_features, self.channels):
483
+ hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
484
+ self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
485
+
486
+ # initialize weights and apply final processing
487
+ self.post_init()
488
+
489
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
490
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
491
+ def forward(
492
+ self,
493
+ pixel_values: torch.Tensor,
494
+ output_hidden_states: Optional[bool] = None,
495
+ return_dict: Optional[bool] = None,
496
+ ) -> BackboneOutput:
497
+ """
498
+ Returns:
499
+
500
+ Examples:
501
+
502
+ ```python
503
+ >>> from transformers import AutoImageProcessor, AutoBackbone
504
+ >>> import torch
505
+ >>> from PIL import Image
506
+ >>> import requests
507
+
508
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
509
+ >>> image = Image.open(requests.get(url, stream=True).raw)
510
+
511
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
512
+ >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
513
+
514
+ >>> inputs = processor(image, return_tensors="pt")
515
+ >>> outputs = model(**inputs)
516
+ ```"""
517
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
518
+ output_hidden_states = (
519
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
520
+ )
521
+
522
+ embedding_output = self.embeddings(pixel_values)
523
+
524
+ outputs = self.encoder(
525
+ embedding_output,
526
+ output_hidden_states=True,
527
+ return_dict=return_dict,
528
+ )
529
+
530
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
531
+
532
+ feature_maps = ()
533
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
534
+ if stage in self.out_features:
535
+ hidden_state = self.hidden_states_norms[stage](hidden_state)
536
+ feature_maps += (hidden_state,)
537
+
538
+ if not return_dict:
539
+ output = (feature_maps,)
540
+ if output_hidden_states:
541
+ output += (hidden_states,)
542
+ return output
543
+
544
+ return BackboneOutput(
545
+ feature_maps=feature_maps,
546
+ hidden_states=hidden_states if output_hidden_states else None,
547
+ attentions=None,
548
+ )
549
+
550
+
551
+ __all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]
.venv/lib/python3.11/site-packages/transformers/models/convnext/modeling_tf_convnext.py ADDED
@@ -0,0 +1,669 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 ConvNext model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import tensorflow as tf
23
+
24
+ from ...activations_tf import get_tf_activation
25
+ from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
26
+ from ...modeling_tf_utils import (
27
+ TFModelInputType,
28
+ TFPreTrainedModel,
29
+ TFSequenceClassificationLoss,
30
+ get_initializer,
31
+ keras,
32
+ keras_serializable,
33
+ unpack_inputs,
34
+ )
35
+ from ...tf_utils import shape_list
36
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
37
+ from .configuration_convnext import ConvNextConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ _CONFIG_FOR_DOC = "ConvNextConfig"
44
+ _CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
45
+
46
+
47
+ class TFConvNextDropPath(keras.layers.Layer):
48
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
49
+ References:
50
+ (1) github.com:rwightman/pytorch-image-models
51
+ """
52
+
53
+ def __init__(self, drop_path: float, **kwargs):
54
+ super().__init__(**kwargs)
55
+ self.drop_path = drop_path
56
+
57
+ def call(self, x: tf.Tensor, training=None):
58
+ if training:
59
+ keep_prob = 1 - self.drop_path
60
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
61
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
62
+ random_tensor = tf.floor(random_tensor)
63
+ return (x / keep_prob) * random_tensor
64
+ return x
65
+
66
+
67
+ class TFConvNextEmbeddings(keras.layers.Layer):
68
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
69
+ found in src/transformers/models/swin/modeling_swin.py.
70
+ """
71
+
72
+ def __init__(self, config: ConvNextConfig, **kwargs):
73
+ super().__init__(**kwargs)
74
+ self.patch_embeddings = keras.layers.Conv2D(
75
+ filters=config.hidden_sizes[0],
76
+ kernel_size=config.patch_size,
77
+ strides=config.patch_size,
78
+ name="patch_embeddings",
79
+ kernel_initializer=get_initializer(config.initializer_range),
80
+ bias_initializer=keras.initializers.Zeros(),
81
+ )
82
+ self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
83
+ self.num_channels = config.num_channels
84
+ self.config = config
85
+
86
+ def call(self, pixel_values):
87
+ if isinstance(pixel_values, dict):
88
+ pixel_values = pixel_values["pixel_values"]
89
+
90
+ tf.debugging.assert_equal(
91
+ shape_list(pixel_values)[1],
92
+ self.num_channels,
93
+ message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
94
+ )
95
+
96
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
97
+ # So change the input format from `NCHW` to `NHWC`.
98
+ # shape = (batch_size, in_height, in_width, in_channels)
99
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
100
+
101
+ embeddings = self.patch_embeddings(pixel_values)
102
+ embeddings = self.layernorm(embeddings)
103
+ return embeddings
104
+
105
+ def build(self, input_shape=None):
106
+ if self.built:
107
+ return
108
+ self.built = True
109
+ if getattr(self, "patch_embeddings", None) is not None:
110
+ with tf.name_scope(self.patch_embeddings.name):
111
+ self.patch_embeddings.build([None, None, None, self.config.num_channels])
112
+ if getattr(self, "layernorm", None) is not None:
113
+ with tf.name_scope(self.layernorm.name):
114
+ self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
115
+
116
+
117
+ class TFConvNextLayer(keras.layers.Layer):
118
+ """This corresponds to the `Block` class in the original implementation.
119
+
120
+ There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv]; all in (N, C,
121
+ H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
122
+
123
+ The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
124
+ NHWC ordering, we can just apply the operations straight-away without the permutation.
125
+
126
+ Args:
127
+ config ([`ConvNextConfig`]): Model configuration class.
128
+ dim (`int`): Number of input channels.
129
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
130
+ """
131
+
132
+ def __init__(self, config, dim, drop_path=0.0, **kwargs):
133
+ super().__init__(**kwargs)
134
+ self.dim = dim
135
+ self.config = config
136
+ self.dwconv = keras.layers.Conv2D(
137
+ filters=dim,
138
+ kernel_size=7,
139
+ padding="same",
140
+ groups=dim,
141
+ kernel_initializer=get_initializer(config.initializer_range),
142
+ bias_initializer="zeros",
143
+ name="dwconv",
144
+ ) # depthwise conv
145
+ self.layernorm = keras.layers.LayerNormalization(
146
+ epsilon=1e-6,
147
+ name="layernorm",
148
+ )
149
+ self.pwconv1 = keras.layers.Dense(
150
+ units=4 * dim,
151
+ kernel_initializer=get_initializer(config.initializer_range),
152
+ bias_initializer="zeros",
153
+ name="pwconv1",
154
+ ) # pointwise/1x1 convs, implemented with linear layers
155
+ self.act = get_tf_activation(config.hidden_act)
156
+ self.pwconv2 = keras.layers.Dense(
157
+ units=dim,
158
+ kernel_initializer=get_initializer(config.initializer_range),
159
+ bias_initializer="zeros",
160
+ name="pwconv2",
161
+ )
162
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
163
+ # behaviour.
164
+ self.drop_path = (
165
+ TFConvNextDropPath(drop_path, name="drop_path")
166
+ if drop_path > 0.0
167
+ else keras.layers.Activation("linear", name="drop_path")
168
+ )
169
+
170
+ def build(self, input_shape: tf.TensorShape = None):
171
+ # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
172
+ self.layer_scale_parameter = (
173
+ self.add_weight(
174
+ shape=(self.dim,),
175
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
176
+ trainable=True,
177
+ name="layer_scale_parameter",
178
+ )
179
+ if self.config.layer_scale_init_value > 0
180
+ else None
181
+ )
182
+
183
+ if self.built:
184
+ return
185
+ self.built = True
186
+ if getattr(self, "dwconv", None) is not None:
187
+ with tf.name_scope(self.dwconv.name):
188
+ self.dwconv.build([None, None, None, self.dim])
189
+ if getattr(self, "layernorm", None) is not None:
190
+ with tf.name_scope(self.layernorm.name):
191
+ self.layernorm.build([None, None, None, self.dim])
192
+ if getattr(self, "pwconv1", None) is not None:
193
+ with tf.name_scope(self.pwconv1.name):
194
+ self.pwconv1.build([None, None, self.dim])
195
+ if getattr(self, "pwconv2", None) is not None:
196
+ with tf.name_scope(self.pwconv2.name):
197
+ self.pwconv2.build([None, None, 4 * self.dim])
198
+ if getattr(self, "drop_path", None) is not None:
199
+ with tf.name_scope(self.drop_path.name):
200
+ self.drop_path.build(None)
201
+
202
+ def call(self, hidden_states, training=False):
203
+ input = hidden_states
204
+ x = self.dwconv(hidden_states)
205
+ x = self.layernorm(x)
206
+ x = self.pwconv1(x)
207
+ x = self.act(x)
208
+ x = self.pwconv2(x)
209
+
210
+ if self.layer_scale_parameter is not None:
211
+ x = self.layer_scale_parameter * x
212
+
213
+ x = input + self.drop_path(x, training=training)
214
+ return x
215
+
216
+
217
+ class TFConvNextStage(keras.layers.Layer):
218
+ """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
219
+
220
+ Args:
221
+ config (`ConvNextV2Config`):
222
+ Model configuration class.
223
+ in_channels (`int`):
224
+ Number of input channels.
225
+ out_channels (`int`):
226
+ Number of output channels.
227
+ depth (`int`):
228
+ Number of residual blocks.
229
+ drop_path_rates(`List[float]`):
230
+ Stochastic depth rates for each layer.
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ config: ConvNextConfig,
236
+ in_channels: int,
237
+ out_channels: int,
238
+ kernel_size: int = 2,
239
+ stride: int = 2,
240
+ depth: int = 2,
241
+ drop_path_rates: Optional[List[float]] = None,
242
+ **kwargs,
243
+ ):
244
+ super().__init__(**kwargs)
245
+ if in_channels != out_channels or stride > 1:
246
+ self.downsampling_layer = [
247
+ keras.layers.LayerNormalization(
248
+ epsilon=1e-6,
249
+ name="downsampling_layer.0",
250
+ ),
251
+ # Inputs to this layer will follow NHWC format since we
252
+ # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
253
+ # layer. All the outputs throughout the model will be in NHWC
254
+ # from this point on until the output where we again change to
255
+ # NCHW.
256
+ keras.layers.Conv2D(
257
+ filters=out_channels,
258
+ kernel_size=kernel_size,
259
+ strides=stride,
260
+ kernel_initializer=get_initializer(config.initializer_range),
261
+ bias_initializer=keras.initializers.Zeros(),
262
+ name="downsampling_layer.1",
263
+ ),
264
+ ]
265
+ else:
266
+ self.downsampling_layer = [tf.identity]
267
+
268
+ drop_path_rates = drop_path_rates or [0.0] * depth
269
+ self.layers = [
270
+ TFConvNextLayer(
271
+ config,
272
+ dim=out_channels,
273
+ drop_path=drop_path_rates[j],
274
+ name=f"layers.{j}",
275
+ )
276
+ for j in range(depth)
277
+ ]
278
+ self.in_channels = in_channels
279
+ self.out_channels = out_channels
280
+ self.stride = stride
281
+
282
+ def call(self, hidden_states):
283
+ for layer in self.downsampling_layer:
284
+ hidden_states = layer(hidden_states)
285
+ for layer in self.layers:
286
+ hidden_states = layer(hidden_states)
287
+ return hidden_states
288
+
289
+ def build(self, input_shape=None):
290
+ if self.built:
291
+ return
292
+ self.built = True
293
+ if getattr(self, "layers", None) is not None:
294
+ for layer in self.layers:
295
+ with tf.name_scope(layer.name):
296
+ layer.build(None)
297
+ if self.in_channels != self.out_channels or self.stride > 1:
298
+ with tf.name_scope(self.downsampling_layer[0].name):
299
+ self.downsampling_layer[0].build([None, None, None, self.in_channels])
300
+ with tf.name_scope(self.downsampling_layer[1].name):
301
+ self.downsampling_layer[1].build([None, None, None, self.in_channels])
302
+
303
+
304
+ class TFConvNextEncoder(keras.layers.Layer):
305
+ def __init__(self, config, **kwargs):
306
+ super().__init__(**kwargs)
307
+ self.stages = []
308
+ drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
309
+ drop_path_rates = tf.split(drop_path_rates, config.depths)
310
+ drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
311
+ prev_chs = config.hidden_sizes[0]
312
+ for i in range(config.num_stages):
313
+ out_chs = config.hidden_sizes[i]
314
+ stage = TFConvNextStage(
315
+ config,
316
+ in_channels=prev_chs,
317
+ out_channels=out_chs,
318
+ stride=2 if i > 0 else 1,
319
+ depth=config.depths[i],
320
+ drop_path_rates=drop_path_rates[i],
321
+ name=f"stages.{i}",
322
+ )
323
+ self.stages.append(stage)
324
+ prev_chs = out_chs
325
+
326
+ def call(self, hidden_states, output_hidden_states=False, return_dict=True):
327
+ all_hidden_states = () if output_hidden_states else None
328
+
329
+ for i, layer_module in enumerate(self.stages):
330
+ if output_hidden_states:
331
+ all_hidden_states = all_hidden_states + (hidden_states,)
332
+
333
+ hidden_states = layer_module(hidden_states)
334
+
335
+ if output_hidden_states:
336
+ all_hidden_states = all_hidden_states + (hidden_states,)
337
+
338
+ if not return_dict:
339
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
340
+
341
+ return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
342
+
343
+ def build(self, input_shape=None):
344
+ for stage in self.stages:
345
+ with tf.name_scope(stage.name):
346
+ stage.build(None)
347
+
348
+
349
+ @keras_serializable
350
+ class TFConvNextMainLayer(keras.layers.Layer):
351
+ config_class = ConvNextConfig
352
+
353
+ def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
354
+ super().__init__(**kwargs)
355
+
356
+ self.config = config
357
+ self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
358
+ self.encoder = TFConvNextEncoder(config, name="encoder")
359
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
360
+ # We are setting the `data_format` like so because from here on we will revert to the
361
+ # NCHW output format
362
+ self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
363
+
364
+ @unpack_inputs
365
+ def call(
366
+ self,
367
+ pixel_values: TFModelInputType | None = None,
368
+ output_hidden_states: Optional[bool] = None,
369
+ return_dict: Optional[bool] = None,
370
+ training: bool = False,
371
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
372
+ output_hidden_states = (
373
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
374
+ )
375
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
376
+
377
+ if pixel_values is None:
378
+ raise ValueError("You have to specify pixel_values")
379
+
380
+ embedding_output = self.embeddings(pixel_values, training=training)
381
+
382
+ encoder_outputs = self.encoder(
383
+ embedding_output,
384
+ output_hidden_states=output_hidden_states,
385
+ return_dict=return_dict,
386
+ training=training,
387
+ )
388
+
389
+ last_hidden_state = encoder_outputs[0]
390
+ # Change to NCHW output format have uniformity in the modules
391
+ last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
392
+ pooled_output = self.layernorm(self.pooler(last_hidden_state))
393
+
394
+ # Change the other hidden state outputs to NCHW as well
395
+ if output_hidden_states:
396
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
397
+
398
+ if not return_dict:
399
+ hidden_states = hidden_states if output_hidden_states else ()
400
+ return (last_hidden_state, pooled_output) + hidden_states
401
+
402
+ return TFBaseModelOutputWithPooling(
403
+ last_hidden_state=last_hidden_state,
404
+ pooler_output=pooled_output,
405
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
406
+ )
407
+
408
+ def build(self, input_shape=None):
409
+ if self.built:
410
+ return
411
+ self.built = True
412
+ if getattr(self, "embeddings", None) is not None:
413
+ with tf.name_scope(self.embeddings.name):
414
+ self.embeddings.build(None)
415
+ if getattr(self, "encoder", None) is not None:
416
+ with tf.name_scope(self.encoder.name):
417
+ self.encoder.build(None)
418
+ if getattr(self, "layernorm", None) is not None:
419
+ with tf.name_scope(self.layernorm.name):
420
+ self.layernorm.build([None, self.config.hidden_sizes[-1]])
421
+
422
+
423
+ class TFConvNextPreTrainedModel(TFPreTrainedModel):
424
+ """
425
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
426
+ models.
427
+ """
428
+
429
+ config_class = ConvNextConfig
430
+ base_model_prefix = "convnext"
431
+ main_input_name = "pixel_values"
432
+
433
+
434
+ CONVNEXT_START_DOCSTRING = r"""
435
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
436
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
437
+ etc.)
438
+
439
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
440
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
441
+ behavior.
442
+
443
+ <Tip>
444
+
445
+ TensorFlow models and layers in `transformers` accept two formats as input:
446
+
447
+ - having all inputs as keyword arguments (like PyTorch models), or
448
+ - having all inputs as a list, tuple or dict in the first positional argument.
449
+
450
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
451
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
452
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
453
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
454
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
455
+ positional argument:
456
+
457
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
458
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
459
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
460
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
461
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
462
+
463
+ Note that when creating models and layers with
464
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
465
+ about any of this, as you can just pass inputs like you would to any other Python function!
466
+
467
+ </Tip>
468
+
469
+ Parameters:
470
+ config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
471
+ Initializing with a config file does not load the weights associated with the model, only the
472
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
473
+ """
474
+
475
+ CONVNEXT_INPUTS_DOCSTRING = r"""
476
+ Args:
477
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
478
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
479
+ [`ConvNextImageProcessor.__call__`] for details.
480
+
481
+ output_hidden_states (`bool`, *optional*):
482
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
483
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
484
+ used instead.
485
+ return_dict (`bool`, *optional*):
486
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
487
+ eager mode, in graph mode the value will always be set to True.
488
+ """
489
+
490
+
491
+ @add_start_docstrings(
492
+ "The bare ConvNext model outputting raw features without any specific head on top.",
493
+ CONVNEXT_START_DOCSTRING,
494
+ )
495
+ class TFConvNextModel(TFConvNextPreTrainedModel):
496
+ def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
497
+ super().__init__(config, *inputs, **kwargs)
498
+ self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
499
+
500
+ @unpack_inputs
501
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
502
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
503
+ def call(
504
+ self,
505
+ pixel_values: TFModelInputType | None = None,
506
+ output_hidden_states: Optional[bool] = None,
507
+ return_dict: Optional[bool] = None,
508
+ training: bool = False,
509
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
510
+ r"""
511
+ Returns:
512
+
513
+ Examples:
514
+
515
+ ```python
516
+ >>> from transformers import AutoImageProcessor, TFConvNextModel
517
+ >>> from PIL import Image
518
+ >>> import requests
519
+
520
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
521
+ >>> image = Image.open(requests.get(url, stream=True).raw)
522
+
523
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
524
+ >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
525
+
526
+ >>> inputs = image_processor(images=image, return_tensors="tf")
527
+ >>> outputs = model(**inputs)
528
+ >>> last_hidden_states = outputs.last_hidden_state
529
+ ```"""
530
+ output_hidden_states = (
531
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
532
+ )
533
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
534
+
535
+ if pixel_values is None:
536
+ raise ValueError("You have to specify pixel_values")
537
+
538
+ outputs = self.convnext(
539
+ pixel_values=pixel_values,
540
+ output_hidden_states=output_hidden_states,
541
+ return_dict=return_dict,
542
+ training=training,
543
+ )
544
+
545
+ if not return_dict:
546
+ return (outputs[0],) + outputs[1:]
547
+
548
+ return TFBaseModelOutputWithPooling(
549
+ last_hidden_state=outputs.last_hidden_state,
550
+ pooler_output=outputs.pooler_output,
551
+ hidden_states=outputs.hidden_states,
552
+ )
553
+
554
+ def build(self, input_shape=None):
555
+ if self.built:
556
+ return
557
+ self.built = True
558
+ if getattr(self, "convnext", None) is not None:
559
+ with tf.name_scope(self.convnext.name):
560
+ self.convnext.build(None)
561
+
562
+
563
+ @add_start_docstrings(
564
+ """
565
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
566
+ ImageNet.
567
+ """,
568
+ CONVNEXT_START_DOCSTRING,
569
+ )
570
+ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
571
+ def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
572
+ super().__init__(config, *inputs, **kwargs)
573
+
574
+ self.num_labels = config.num_labels
575
+ self.convnext = TFConvNextMainLayer(config, name="convnext")
576
+
577
+ # Classifier head
578
+ self.classifier = keras.layers.Dense(
579
+ units=config.num_labels,
580
+ kernel_initializer=get_initializer(config.initializer_range),
581
+ bias_initializer="zeros",
582
+ name="classifier",
583
+ )
584
+ self.config = config
585
+
586
+ @unpack_inputs
587
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
588
+ @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
589
+ def call(
590
+ self,
591
+ pixel_values: TFModelInputType | None = None,
592
+ output_hidden_states: Optional[bool] = None,
593
+ return_dict: Optional[bool] = None,
594
+ labels: np.ndarray | tf.Tensor | None = None,
595
+ training: Optional[bool] = False,
596
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
597
+ r"""
598
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
599
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
600
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
601
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
602
+
603
+ Returns:
604
+
605
+ Examples:
606
+
607
+ ```python
608
+ >>> from transformers import AutoImageProcessor, TFConvNextForImageClassification
609
+ >>> import tensorflow as tf
610
+ >>> from PIL import Image
611
+ >>> import requests
612
+
613
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
614
+ >>> image = Image.open(requests.get(url, stream=True).raw)
615
+
616
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
617
+ >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
618
+
619
+ >>> inputs = image_processor(images=image, return_tensors="tf")
620
+ >>> outputs = model(**inputs)
621
+ >>> logits = outputs.logits
622
+ >>> # model predicts one of the 1000 ImageNet classes
623
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
624
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
625
+ ```"""
626
+ output_hidden_states = (
627
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
628
+ )
629
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
630
+
631
+ if pixel_values is None:
632
+ raise ValueError("You have to specify pixel_values")
633
+
634
+ outputs = self.convnext(
635
+ pixel_values,
636
+ output_hidden_states=output_hidden_states,
637
+ return_dict=return_dict,
638
+ training=training,
639
+ )
640
+
641
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
642
+
643
+ logits = self.classifier(pooled_output)
644
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
645
+
646
+ if not return_dict:
647
+ output = (logits,) + outputs[2:]
648
+ return ((loss,) + output) if loss is not None else output
649
+
650
+ return TFSequenceClassifierOutput(
651
+ loss=loss,
652
+ logits=logits,
653
+ hidden_states=outputs.hidden_states,
654
+ )
655
+
656
+ def build(self, input_shape=None):
657
+ if self.built:
658
+ return
659
+ self.built = True
660
+ if getattr(self, "convnext", None) is not None:
661
+ with tf.name_scope(self.convnext.name):
662
+ self.convnext.build(None)
663
+ if getattr(self, "classifier", None) is not None:
664
+ if hasattr(self.classifier, "name"):
665
+ with tf.name_scope(self.classifier.name):
666
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
667
+
668
+
669
+ __all__ = ["TFConvNextForImageClassification", "TFConvNextModel", "TFConvNextPreTrainedModel"]
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_decision_transformer import *
22
+ from .modeling_decision_transformer import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
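
The `__init__.py` above wires the package through `_LazyModule`, so the torch-heavy modeling code is only imported when an attribute is first accessed. A hedged sketch of what that enables, assuming `transformers` with PyTorch is installed:

```python
from transformers.models import decision_transformer

# Accessing an attribute triggers the lazy import of the underlying module.
config_cls = decision_transformer.DecisionTransformerConfig
print(config_cls.__name__)  # "DecisionTransformerConfig"
```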
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (808 Bytes).
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/configuration_decision_transformer.cpython-311.pyc ADDED
Binary file (6.8 kB).
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/__pycache__/modeling_decision_transformer.cpython-311.pyc ADDED
Binary file (47.1 kB).
.venv/lib/python3.11/site-packages/transformers/models/decision_transformer/modeling_decision_transformer.py ADDED
@@ -0,0 +1,963 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Team The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DecisionTransformer model."""
16
+
17
+ import math
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import Callable, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
28
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
29
+ from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
30
+ from ...utils import (
31
+ ModelOutput,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ replace_return_docstrings,
36
+ )
37
+ from .configuration_decision_transformer import DecisionTransformerConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
43
+ _CONFIG_FOR_DOC = "DecisionTransformerConfig"
44
+
45
+
46
+ # Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
47
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
48
+ """Load tf checkpoints in a pytorch model"""
49
+ try:
50
+ import re
51
+
52
+ import tensorflow as tf
53
+ except ImportError:
54
+ logger.error(
55
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
56
+ "https://www.tensorflow.org/install/ for installation instructions."
57
+ )
58
+ raise
59
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
60
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
61
+ # Load weights from TF model
62
+ init_vars = tf.train.list_variables(tf_path)
63
+ names = []
64
+ arrays = []
65
+ for name, shape in init_vars:
66
+ logger.info(f"Loading TF weight {name} with shape {shape}")
67
+ array = tf.train.load_variable(tf_path, name)
68
+ names.append(name)
69
+ arrays.append(array.squeeze())
70
+
71
+ for name, array in zip(names, arrays):
72
+ name = name[6:] # skip "model/"
73
+ name = name.split("/")
74
+ pointer = model
75
+ for m_name in name:
76
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
77
+ scope_names = re.split(r"(\d+)", m_name)
78
+ else:
79
+ scope_names = [m_name]
80
+ if scope_names[0] == "w" or scope_names[0] == "g":
81
+ pointer = getattr(pointer, "weight")
82
+ elif scope_names[0] == "b":
83
+ pointer = getattr(pointer, "bias")
84
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
85
+ pointer = getattr(pointer, scope_names[0])
86
+ pointer = getattr(pointer, "weight")
87
+ else:
88
+ pointer = getattr(pointer, scope_names[0])
89
+ if len(scope_names) >= 2:
90
+ num = int(scope_names[1])
91
+ pointer = pointer[num]
92
+ try:
93
+ if pointer.shape != array.shape:
94
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
95
+ except ValueError as e:
96
+ e.args += (pointer.shape, array.shape)
97
+ raise
98
+ logger.info(f"Initialize PyTorch weight {name}")
99
+ pointer.data = torch.from_numpy(array)
100
+ return model
101
+
102
+
103
+ # Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward
104
+ def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
105
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
106
+
107
+ if module.scale_attn_weights:
108
+ attn_weights = attn_weights / torch.full(
109
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
110
+ )
111
+
112
+ # Layer-wise attention scaling
113
+ if module.scale_attn_by_inverse_layer_idx:
114
+ attn_weights = attn_weights / float(module.layer_idx + 1)
115
+
116
+ if not module.is_cross_attention:
117
+ # if only "normal" attention layer implements causal mask
118
+ query_length, key_length = query.size(-2), key.size(-2)
119
+ causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
120
+ mask_value = torch.finfo(attn_weights.dtype).min
121
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
122
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
123
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
124
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
125
+
126
+ if attention_mask is not None:
127
+ # Apply the attention mask
128
+ attn_weights = attn_weights + attention_mask
129
+
130
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
131
+
132
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
133
+ attn_weights = attn_weights.type(value.dtype)
134
+ attn_weights = module.attn_dropout(attn_weights)
135
+
136
+ # Mask heads if we want to
137
+ if head_mask is not None:
138
+ attn_weights = attn_weights * head_mask
139
+
140
+ attn_output = torch.matmul(attn_weights, value)
141
+ attn_output = attn_output.transpose(1, 2)
142
+
143
+ return attn_output, attn_weights
144
+
145
+
146
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
147
+ class DecisionTransformerGPT2Attention(nn.Module):
148
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
149
+ super().__init__()
150
+ self.config = config
151
+ max_positions = config.max_position_embeddings
152
+ self.register_buffer(
153
+ "bias",
154
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
155
+ 1, 1, max_positions, max_positions
156
+ ),
157
+ persistent=False,
158
+ )
159
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
160
+
161
+ self.embed_dim = config.hidden_size
162
+ self.num_heads = config.num_attention_heads
163
+ self.head_dim = self.embed_dim // self.num_heads
164
+ self.split_size = self.embed_dim
165
+ if self.head_dim * self.num_heads != self.embed_dim:
166
+ raise ValueError(
167
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
168
+ f" {self.num_heads})."
169
+ )
170
+
171
+ self.scale_attn_weights = config.scale_attn_weights
172
+ self.is_cross_attention = is_cross_attention
173
+
174
+ # Layer-wise attention scaling, reordering, and upcasting
175
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
176
+ self.layer_idx = layer_idx
177
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
178
+
179
+ if self.is_cross_attention:
180
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
181
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
182
+ else:
183
+ self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
184
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
185
+
186
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
187
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
188
+ self.is_causal = True
189
+
190
+ self.pruned_heads = set()
191
+
192
+ def prune_heads(self, heads):
193
+ if len(heads) == 0:
194
+ return
195
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
196
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
197
+
198
+ # Prune conv1d layers
199
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
200
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
201
+
202
+ # Update hyper params
203
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
204
+ self.num_heads = self.num_heads - len(heads)
205
+ self.pruned_heads = self.pruned_heads.union(heads)
206
+
207
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
208
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
209
+ bsz, num_heads, q_seq_len, dk = query.size()
210
+ _, _, k_seq_len, _ = key.size()
211
+
212
+ # Preallocate attn_weights for `baddbmm`
213
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
214
+
215
+ # Compute Scale Factor
216
+ scale_factor = 1.0
217
+ if self.scale_attn_weights:
218
+ scale_factor /= float(value.size(-1)) ** 0.5
219
+
220
+ if self.scale_attn_by_inverse_layer_idx:
221
+ scale_factor /= float(self.layer_idx + 1)
222
+
223
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
224
+ with torch.amp.autocast(query.device.type, enabled=False):
225
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
226
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
227
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
228
+
229
+ if not self.is_cross_attention:
230
+ # only the "normal" (non-cross) attention layer implements the causal mask
231
+ query_length, key_length = query.size(-2), key.size(-2)
232
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
233
+ mask_value = torch.finfo(attn_weights.dtype).min
234
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
235
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
236
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
237
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
238
+
239
+ if attention_mask is not None:
240
+ # Apply the attention mask
241
+ attn_weights = attn_weights + attention_mask
242
+
243
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
244
+
245
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
246
+ if attn_weights.dtype != torch.float32:
247
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
248
+ attn_weights = attn_weights.type(value.dtype)
249
+ attn_weights = self.attn_dropout(attn_weights)
250
+
251
+ # Mask heads if we want to
252
+ if head_mask is not None:
253
+ attn_weights = attn_weights * head_mask
254
+
255
+ attn_output = torch.matmul(attn_weights, value)
256
+ attn_output = attn_output.transpose(1, 2)
257
+
258
+ return attn_output, attn_weights
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
263
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
264
+ attention_mask: Optional[torch.FloatTensor] = None,
265
+ head_mask: Optional[torch.FloatTensor] = None,
266
+ encoder_hidden_states: Optional[torch.Tensor] = None,
267
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
268
+ use_cache: Optional[bool] = False,
269
+ output_attentions: Optional[bool] = False,
270
+ **kwargs,
271
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
272
+ if encoder_hidden_states is not None:
273
+ if not hasattr(self, "q_attn"):
274
+ raise ValueError(
275
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
276
+ "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
277
+ )
278
+
279
+ query_states = self.q_attn(hidden_states)
280
+ key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
281
+ attention_mask = encoder_attention_mask
282
+ else:
283
+ query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
284
+
285
+ shape_q = (*query_states.shape[:-1], -1, self.head_dim)
286
+ shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
287
+
288
+ query_states = query_states.view(shape_q).transpose(1, 2)
289
+ key_states = key_states.view(shape_kv).transpose(1, 2)
290
+ value_states = value_states.view(shape_kv).transpose(1, 2)
291
+
292
+ if layer_past is not None:
293
+ past_key, past_value = layer_past
294
+ key_states = torch.cat((past_key, key_states), dim=-2)
295
+ value_states = torch.cat((past_value, value_states), dim=-2)
296
+
297
+ if use_cache is True:
298
+ present = (key_states, value_states)
299
+ else:
300
+ present = None
301
+
302
+ is_cross_attention = encoder_hidden_states is not None
303
+ is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
304
+
305
+ using_eager = self.config._attn_implementation == "eager"
306
+ attention_interface: Callable = eager_attention_forward
307
+ if self.config._attn_implementation != "eager":
308
+ if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
309
+ using_eager = True
310
+ logger.warning_once(
311
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
312
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
313
+ )
314
+ else:
315
+ # Attention functions are consistent with the previous equivalent attention classes; however, they do not support some options
316
+ # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
317
+ # not necessarily to eager (if the mentioned options are provided).
318
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
319
+
320
+ if using_eager and self.reorder_and_upcast_attn:
321
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
322
+ query_states, key_states, value_states, attention_mask, head_mask
323
+ )
324
+ else:
325
+ attn_output, attn_weights = attention_interface(
326
+ self,
327
+ query_states,
328
+ key_states,
329
+ value_states,
330
+ attention_mask,
331
+ head_mask=head_mask,
332
+ dropout=self.attn_dropout.p if self.training else 0.0,
333
+ is_causal=is_causal,
334
+ **kwargs,
335
+ )
336
+
337
+ attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
338
+ attn_output = self.c_proj(attn_output)
339
+ attn_output = self.resid_dropout(attn_output)
340
+
341
+ outputs = (attn_output, present)
342
+ if output_attentions:
343
+ outputs += (attn_weights,)
344
+
345
+ return outputs # a, present, (attentions)
346
+
347
+
348
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
349
+ class DecisionTransformerGPT2MLP(nn.Module):
350
+ def __init__(self, intermediate_size, config):
351
+ super().__init__()
352
+ embed_dim = config.hidden_size
353
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
354
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
355
+ self.act = ACT2FN[config.activation_function]
356
+ self.dropout = nn.Dropout(config.resid_pdrop)
357
+
358
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
359
+ hidden_states = self.c_fc(hidden_states)
360
+ hidden_states = self.act(hidden_states)
361
+ hidden_states = self.c_proj(hidden_states)
362
+ hidden_states = self.dropout(hidden_states)
363
+ return hidden_states
364
+
365
+
366
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
367
+ class DecisionTransformerGPT2Block(nn.Module):
368
+ # Ignore copy
369
+ def __init__(self, config, layer_idx=None):
370
+ super().__init__()
371
+ hidden_size = config.hidden_size
372
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
373
+
374
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
375
+ self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
376
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
377
+
378
+ if config.add_cross_attention:
379
+ self.crossattention = DecisionTransformerGPT2Attention(
380
+ config, is_cross_attention=True, layer_idx=layer_idx
381
+ )
382
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
383
+
384
+ self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
389
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
390
+ attention_mask: Optional[torch.FloatTensor] = None,
391
+ head_mask: Optional[torch.FloatTensor] = None,
392
+ encoder_hidden_states: Optional[torch.Tensor] = None,
393
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
394
+ use_cache: Optional[bool] = False,
395
+ output_attentions: Optional[bool] = False,
396
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
397
+ residual = hidden_states
398
+ hidden_states = self.ln_1(hidden_states)
399
+ attn_outputs = self.attn(
400
+ hidden_states,
401
+ layer_past=layer_past,
402
+ attention_mask=attention_mask,
403
+ head_mask=head_mask,
404
+ use_cache=use_cache,
405
+ output_attentions=output_attentions,
406
+ )
407
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
408
+ outputs = attn_outputs[1:]
409
+ # residual connection
410
+ hidden_states = attn_output + residual
411
+
412
+ if encoder_hidden_states is not None:
413
+ # add one self-attention block for cross-attention
414
+ if not hasattr(self, "crossattention"):
415
+ raise ValueError(
416
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
417
+ "cross-attention layers by setting `config.add_cross_attention=True`"
418
+ )
419
+ residual = hidden_states
420
+ hidden_states = self.ln_cross_attn(hidden_states)
421
+ cross_attn_outputs = self.crossattention(
422
+ hidden_states,
423
+ attention_mask=attention_mask,
424
+ head_mask=head_mask,
425
+ encoder_hidden_states=encoder_hidden_states,
426
+ encoder_attention_mask=encoder_attention_mask,
427
+ output_attentions=output_attentions,
428
+ )
429
+ attn_output = cross_attn_outputs[0]
430
+ # residual connection
431
+ hidden_states = residual + attn_output
432
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
433
+
434
+ residual = hidden_states
435
+ hidden_states = self.ln_2(hidden_states)
436
+ feed_forward_hidden_states = self.mlp(hidden_states)
437
+ # residual connection
438
+ hidden_states = residual + feed_forward_hidden_states
439
+
440
+ if use_cache:
441
+ outputs = (hidden_states,) + outputs
442
+ else:
443
+ outputs = (hidden_states,) + outputs[1:]
444
+
445
+ return outputs # hidden_states, present, (attentions, cross_attentions)
446
+
447
+
448
+ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
449
+ """
450
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
451
+ models.
452
+ """
453
+
454
+ config_class = DecisionTransformerConfig
455
+ load_tf_weights = load_tf_weights_in_gpt2
456
+ base_model_prefix = "transformer"
457
+ is_parallelizable = True
458
+ supports_gradient_checkpointing = True
459
+
460
+ def __init__(self, *inputs, **kwargs):
461
+ super().__init__(*inputs, **kwargs)
462
+
463
+ def _init_weights(self, module):
464
+ """Initialize the weights."""
465
+ if isinstance(module, (nn.Linear, Conv1D)):
466
+ # Slightly different from the TF version which uses truncated_normal for initialization
467
+ # cf https://github.com/pytorch/pytorch/pull/5617
468
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
469
+ if module.bias is not None:
470
+ module.bias.data.zero_()
471
+ elif isinstance(module, nn.Embedding):
472
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
473
+ if module.padding_idx is not None:
474
+ module.weight.data[module.padding_idx].zero_()
475
+ elif isinstance(module, nn.LayerNorm):
476
+ module.bias.data.zero_()
477
+ module.weight.data.fill_(1.0)
478
+
479
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
480
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
481
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
482
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
483
+ #
484
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
485
+ for name, p in module.named_parameters():
486
+ if "c_proj" in name and "weight" in name:
487
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
488
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
489
+
490
+
491
+ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
492
+ def __init__(self, config):
493
+ super().__init__(config)
494
+
495
+ self.embed_dim = config.hidden_size
496
+
497
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
498
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
499
+
500
+ self.drop = nn.Dropout(config.embd_pdrop)
501
+ self.h = nn.ModuleList(
502
+ [DecisionTransformerGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
503
+ )
504
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
505
+
506
+ # Model parallel
507
+ self.model_parallel = False
508
+ self.device_map = None
509
+ self.gradient_checkpointing = False
510
+
511
+ # Initialize weights and apply final processing
512
+ self.post_init()
513
+
514
+ def get_input_embeddings(self):
515
+ return self.wte
516
+
517
+ def set_input_embeddings(self, new_embeddings):
518
+ self.wte = new_embeddings
519
+
520
+ def forward(
521
+ self,
522
+ input_ids: Optional[torch.LongTensor] = None,
523
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
524
+ attention_mask: Optional[torch.FloatTensor] = None,
525
+ token_type_ids: Optional[torch.LongTensor] = None,
526
+ position_ids: Optional[torch.LongTensor] = None,
527
+ head_mask: Optional[torch.FloatTensor] = None,
528
+ inputs_embeds: Optional[torch.FloatTensor] = None,
529
+ encoder_hidden_states: Optional[torch.Tensor] = None,
530
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
531
+ use_cache: Optional[bool] = None,
532
+ output_attentions: Optional[bool] = None,
533
+ output_hidden_states: Optional[bool] = None,
534
+ return_dict: Optional[bool] = None,
535
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
536
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
537
+ output_hidden_states = (
538
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
539
+ )
540
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
541
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
542
+
543
+ if input_ids is not None and inputs_embeds is not None:
544
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
545
+ elif input_ids is not None:
546
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
547
+ input_shape = input_ids.size()
548
+ input_ids = input_ids.view(-1, input_shape[-1])
549
+ batch_size = input_ids.shape[0]
550
+ elif inputs_embeds is not None:
551
+ input_shape = inputs_embeds.size()[:-1]
552
+ batch_size = inputs_embeds.shape[0]
553
+ else:
554
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
555
+
556
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
557
+
558
+ if token_type_ids is not None:
559
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
560
+
561
+ if past_key_values is None:
562
+ past_length = 0
563
+ past_key_values = tuple([None] * len(self.h))
564
+ else:
565
+ past_length = past_key_values[0][0].size(-2)
566
+ if position_ids is None:
567
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
568
+ position_ids = position_ids.unsqueeze(0)
569
+
570
+ # Attention mask.
571
+ if attention_mask is not None:
572
+ if batch_size <= 0:
573
+ raise ValueError("batch_size has to be defined and > 0")
574
+ attention_mask = attention_mask.view(batch_size, -1)
575
+ # We create a 3D attention mask from a 2D tensor mask.
576
+ # Sizes are [batch_size, 1, 1, to_seq_length]
577
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
578
+ # this attention mask is more simple than the triangular masking of causal attention
579
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
580
+ attention_mask = attention_mask[:, None, None, :]
581
+
582
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
583
+ # masked positions, this operation will create a tensor which is 0.0 for
584
+ # positions we want to attend and the dtype's smallest value for masked positions.
585
+ # Since we are adding it to the raw scores before the softmax, this is
586
+ # effectively the same as removing these entirely.
587
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
588
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
589
+
590
+ # If a 2D or 3D attention mask is provided for the cross-attention
591
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
592
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
593
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
594
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
595
+ if encoder_attention_mask is None:
596
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
597
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
598
+ else:
599
+ encoder_attention_mask = None
600
+
601
+ # Prepare head mask if needed
602
+ # 1.0 in head_mask indicate we keep the head
603
+ # attention_probs has shape bsz x n_heads x N x N
604
+ # head_mask has shape n_layer x batch x n_heads x N x N
605
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
606
+
607
+ if inputs_embeds is None:
608
+ inputs_embeds = self.wte(input_ids)
609
+ position_embeds = self.wpe(position_ids)
610
+ hidden_states = inputs_embeds + position_embeds
611
+
612
+ if token_type_ids is not None:
613
+ token_type_embeds = self.wte(token_type_ids)
614
+ hidden_states = hidden_states + token_type_embeds
615
+
616
+ hidden_states = self.drop(hidden_states)
617
+
618
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
619
+
620
+ if self.gradient_checkpointing and self.training:
621
+ if use_cache:
622
+ logger.warning_once(
623
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
624
+ )
625
+ use_cache = False
626
+
627
+ presents = () if use_cache else None
628
+ all_self_attentions = () if output_attentions else None
629
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
630
+ all_hidden_states = () if output_hidden_states else None
631
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
632
+ # Model parallel
633
+ if self.model_parallel:
634
+ torch.cuda.set_device(hidden_states.device)
635
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
636
+ if layer_past is not None:
637
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
638
+ # Ensure that attention_mask is always on the same device as hidden_states
639
+ if attention_mask is not None:
640
+ attention_mask = attention_mask.to(hidden_states.device)
641
+ if isinstance(head_mask, torch.Tensor):
642
+ head_mask = head_mask.to(hidden_states.device)
643
+ if output_hidden_states:
644
+ all_hidden_states = all_hidden_states + (hidden_states,)
645
+
646
+ if self.gradient_checkpointing and self.training:
647
+ outputs = self._gradient_checkpointing_func(
648
+ block.__call__,
649
+ hidden_states,
650
+ None,
651
+ attention_mask,
652
+ head_mask[i],
653
+ encoder_hidden_states,
654
+ encoder_attention_mask,
655
+ use_cache,
656
+ output_attentions,
657
+ )
658
+ else:
659
+ outputs = block(
660
+ hidden_states,
661
+ layer_past=layer_past,
662
+ attention_mask=attention_mask,
663
+ head_mask=head_mask[i],
664
+ encoder_hidden_states=encoder_hidden_states,
665
+ encoder_attention_mask=encoder_attention_mask,
666
+ use_cache=use_cache,
667
+ output_attentions=output_attentions,
668
+ )
669
+
670
+ hidden_states = outputs[0]
671
+ if use_cache is True:
672
+ presents = presents + (outputs[1],)
673
+
674
+ if output_attentions:
675
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
676
+ if self.config.add_cross_attention:
677
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
678
+
679
+ # Model Parallel: If it's the last layer for that device, put things on the next device
680
+ if self.model_parallel:
681
+ for k, v in self.device_map.items():
682
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
683
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
684
+
685
+ hidden_states = self.ln_f(hidden_states)
686
+
687
+ hidden_states = hidden_states.view(output_shape)
688
+ # Add last hidden state
689
+ if output_hidden_states:
690
+ all_hidden_states = all_hidden_states + (hidden_states,)
691
+
692
+ if not return_dict:
693
+ return tuple(
694
+ v
695
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
696
+ if v is not None
697
+ )
698
+
699
+ return BaseModelOutputWithPastAndCrossAttentions(
700
+ last_hidden_state=hidden_states,
701
+ past_key_values=presents,
702
+ hidden_states=all_hidden_states,
703
+ attentions=all_self_attentions,
704
+ cross_attentions=all_cross_attentions,
705
+ )
706
+
707
+
708
+ @dataclass
709
+ class DecisionTransformerOutput(ModelOutput):
710
+ """
711
+ Base class for the outputs of [`DecisionTransformerModel`].
712
+
713
+ Args:
714
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
715
+ Sequence of hidden-states at the output of the last layer of the model.
716
+ state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
717
+ Environment state predictions
718
+ action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
719
+ Model action predictions
720
+ return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
721
+ Predicted returns for each state
722
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
723
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
724
+ shape `(batch_size, sequence_length, hidden_size)`.
725
+
726
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
727
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
728
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
729
+ sequence_length)`.
730
+
731
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
732
+ heads.
733
+ """
734
+
735
+ state_preds: torch.FloatTensor = None
736
+ action_preds: torch.FloatTensor = None
737
+ return_preds: torch.FloatTensor = None
738
+ hidden_states: torch.FloatTensor = None
739
+ attentions: torch.FloatTensor = None
740
+ last_hidden_state: torch.FloatTensor = None
741
+
742
+
743
+ class DecisionTransformerPreTrainedModel(PreTrainedModel):
744
+ """
745
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
746
+ models.
747
+ """
748
+
749
+ config_class = DecisionTransformerConfig
750
+ base_model_prefix = "decision_transformer"
751
+ main_input_name = "states"
752
+ supports_gradient_checkpointing = False
753
+
754
+ def _init_weights(self, module):
755
+ """Initialize the weights"""
756
+ if isinstance(module, nn.Linear):
757
+ # Slightly different from the TF version which uses truncated_normal for initialization
758
+ # cf https://github.com/pytorch/pytorch/pull/5617
759
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
760
+ if module.bias is not None:
761
+ module.bias.data.zero_()
762
+ elif isinstance(module, nn.Embedding):
763
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
764
+ if module.padding_idx is not None:
765
+ module.weight.data[module.padding_idx].zero_()
766
+ elif isinstance(module, nn.LayerNorm):
767
+ module.bias.data.zero_()
768
+ module.weight.data.fill_(1.0)
769
+
770
+
771
+ DECISION_TRANSFORMER_START_DOCSTRING = r"""
772
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
773
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
774
+ behavior.
775
+
776
+ Parameters:
777
+ config ([`~DecisionTransformerConfig`]): Model configuration class with all the parameters of the model.
778
+ Initializing with a config file does not load the weights associated with the model, only the
779
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
780
+ """
781
+
782
+ DECISION_TRANSFORMER_INPUTS_DOCSTRING = r"""
783
+ Args:
784
+ states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
785
+ The states for each step in the trajectory
786
+ actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
787
+ The actions taken by the "expert" policy for the current state; these are masked for autoregressive
788
+ prediction
789
+ rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
790
+ The rewards for each state, action
791
+ returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
792
+ The returns for each state in the trajectory
793
+ timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
794
+ The timestep for each step in the trajectory
795
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
796
+ Masking, used to mask the actions when performing autoregressive prediction
797
+ """
798
+
799
+
800
+ @add_start_docstrings("The Decision Transformer Model", DECISION_TRANSFORMER_START_DOCSTRING)
801
+ class DecisionTransformerModel(DecisionTransformerPreTrainedModel):
802
+ """
803
+
804
+ The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL
805
+ setting. Refer to the paper for more details: https://arxiv.org/abs/2106.01345
806
+
807
+ """
808
+
809
+ def __init__(self, config):
810
+ super().__init__(config)
811
+ self.config = config
812
+ self.hidden_size = config.hidden_size
813
+ # note: the only difference between this GPT2Model and the default Huggingface version
814
+ # is that the positional embeddings are removed (since we'll add those ourselves)
815
+ self.encoder = DecisionTransformerGPT2Model(config)
816
+
817
+ self.embed_timestep = nn.Embedding(config.max_ep_len, config.hidden_size)
818
+ self.embed_return = torch.nn.Linear(1, config.hidden_size)
819
+ self.embed_state = torch.nn.Linear(config.state_dim, config.hidden_size)
820
+ self.embed_action = torch.nn.Linear(config.act_dim, config.hidden_size)
821
+
822
+ self.embed_ln = nn.LayerNorm(config.hidden_size)
823
+
824
+ # note: we don't predict states or returns for the paper
825
+ self.predict_state = torch.nn.Linear(config.hidden_size, config.state_dim)
826
+ self.predict_action = nn.Sequential(
827
+ *([nn.Linear(config.hidden_size, config.act_dim)] + ([nn.Tanh()] if config.action_tanh else []))
828
+ )
829
+ self.predict_return = torch.nn.Linear(config.hidden_size, 1)
830
+
831
+ # Initialize weights and apply final processing
832
+ self.post_init()
833
+
834
+ @add_start_docstrings_to_model_forward(DECISION_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
835
+ @replace_return_docstrings(output_type=DecisionTransformerOutput, config_class=_CONFIG_FOR_DOC)
836
+ def forward(
837
+ self,
838
+ states: Optional[torch.FloatTensor] = None,
839
+ actions: Optional[torch.FloatTensor] = None,
840
+ rewards: Optional[torch.FloatTensor] = None,
841
+ returns_to_go: Optional[torch.FloatTensor] = None,
842
+ timesteps: Optional[torch.LongTensor] = None,
843
+ attention_mask: Optional[torch.FloatTensor] = None,
844
+ output_hidden_states: Optional[bool] = None,
845
+ output_attentions: Optional[bool] = None,
846
+ return_dict: Optional[bool] = None,
847
+ ) -> Union[Tuple[torch.FloatTensor], DecisionTransformerOutput]:
848
+ r"""
849
+ Returns:
850
+
851
+ Examples:
852
+
853
+ ```python
854
+ >>> from transformers import DecisionTransformerModel
855
+ >>> import torch
856
+
857
+ >>> model = DecisionTransformerModel.from_pretrained("edbeeching/decision-transformer-gym-hopper-medium")
858
+ >>> # evaluation
859
+ >>> model = model.to(device)
860
+ >>> model.eval()
861
+
862
+ >>> env = gym.make("Hopper-v3")
863
+ >>> state_dim = env.observation_space.shape[0]
864
+ >>> act_dim = env.action_space.shape[0]
865
+
866
+ >>> state = env.reset()
867
+ >>> states = torch.from_numpy(state).reshape(1, 1, state_dim).to(device=device, dtype=torch.float32)
868
+ >>> actions = torch.zeros((1, 1, act_dim), device=device, dtype=torch.float32)
869
+ >>> rewards = torch.zeros(1, 1, device=device, dtype=torch.float32)
870
+ >>> target_return = torch.tensor(TARGET_RETURN, dtype=torch.float32).reshape(1, 1)
871
+ >>> timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)
872
+ >>> attention_mask = torch.zeros(1, 1, device=device, dtype=torch.float32)
873
+
874
+ >>> # forward pass
875
+ >>> with torch.no_grad():
876
+ ... state_preds, action_preds, return_preds = model(
877
+ ... states=states,
878
+ ... actions=actions,
879
+ ... rewards=rewards,
880
+ ... returns_to_go=target_return,
881
+ ... timesteps=timesteps,
882
+ ... attention_mask=attention_mask,
883
+ ... return_dict=False,
884
+ ... )
885
+ ```"""
886
+
887
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
888
+ output_hidden_states = (
889
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
890
+ )
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ batch_size, seq_length = states.shape[0], states.shape[1]
894
+
895
+ if attention_mask is None:
896
+ # attention mask for GPT: 1 if can be attended to, 0 if not
897
+ attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
898
+
899
+ # embed each modality with a different head
900
+ state_embeddings = self.embed_state(states)
901
+ action_embeddings = self.embed_action(actions)
902
+ returns_embeddings = self.embed_return(returns_to_go)
903
+ time_embeddings = self.embed_timestep(timesteps)
904
+
905
+ # time embeddings are treated similar to positional embeddings
906
+ state_embeddings = state_embeddings + time_embeddings
907
+ action_embeddings = action_embeddings + time_embeddings
908
+ returns_embeddings = returns_embeddings + time_embeddings
909
+
910
+ # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
911
+ # which works nicely in an autoregressive sense since states predict actions
912
+ stacked_inputs = (
913
+ torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
914
+ .permute(0, 2, 1, 3)
915
+ .reshape(batch_size, 3 * seq_length, self.hidden_size)
916
+ )
917
+ stacked_inputs = self.embed_ln(stacked_inputs)
918
+
919
+ # to make the attention mask fit the stacked inputs, have to stack it as well
920
+ stacked_attention_mask = (
921
+ torch.stack((attention_mask, attention_mask, attention_mask), dim=1)
922
+ .permute(0, 2, 1)
923
+ .reshape(batch_size, 3 * seq_length)
924
+ )
925
+ device = stacked_inputs.device
926
+ # we feed in the input embeddings (not word indices as in NLP) to the model
927
+ encoder_outputs = self.encoder(
928
+ inputs_embeds=stacked_inputs,
929
+ attention_mask=stacked_attention_mask,
930
+ position_ids=torch.zeros(stacked_attention_mask.shape, device=device, dtype=torch.long),
931
+ output_attentions=output_attentions,
932
+ output_hidden_states=output_hidden_states,
933
+ return_dict=return_dict,
934
+ )
935
+ x = encoder_outputs[0]
936
+
937
+ # reshape x so that the second dimension corresponds to the original
938
+ # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
939
+ x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
940
+
941
+ # get predictions
942
+ return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
943
+ state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
944
+ action_preds = self.predict_action(x[:, 1]) # predict next action given state
945
+ if not return_dict:
946
+ return (state_preds, action_preds, return_preds)
947
+
948
+ return DecisionTransformerOutput(
949
+ last_hidden_state=encoder_outputs.last_hidden_state,
950
+ state_preds=state_preds,
951
+ action_preds=action_preds,
952
+ return_preds=return_preds,
953
+ hidden_states=encoder_outputs.hidden_states,
954
+ attentions=encoder_outputs.attentions,
955
+ )
956
+
957
+
958
+ __all__ = [
959
+ "DecisionTransformerGPT2Model",
960
+ "DecisionTransformerGPT2PreTrainedModel",
961
+ "DecisionTransformerModel",
962
+ "DecisionTransformerPreTrainedModel",
963
+ ]
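
The `DecisionTransformerModel.forward` above interleaves the return-to-go, state and action embeddings into a single (R_1, s_1, a_1, R_2, s_2, a_2, ...) sequence before feeding the GPT-2 backbone. A minimal usage sketch is shown below; it is not part of the file above, the shapes follow `DECISION_TRANSFORMER_INPUTS_DOCSTRING`, and the config values are purely illustrative rather than taken from any released checkpoint.

```python
import torch
from transformers import DecisionTransformerConfig, DecisionTransformerModel

# Illustrative dimensions only; a real checkpoint defines its own state/action sizes.
config = DecisionTransformerConfig(state_dim=11, act_dim=3, max_ep_len=1000)
model = DecisionTransformerModel(config)
model.eval()

batch_size, seq_length = 1, 20
states = torch.randn(batch_size, seq_length, config.state_dim)
actions = torch.randn(batch_size, seq_length, config.act_dim)
rewards = torch.randn(batch_size, seq_length, 1)
returns_to_go = torch.randn(batch_size, seq_length, 1)
timesteps = torch.arange(seq_length).unsqueeze(0)                  # (batch_size, seq_length)
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

with torch.no_grad():
    outputs = model(
        states=states,
        actions=actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
    )

# action_preds has shape (batch_size, seq_length, act_dim); the last step is the
# action the learned policy would take next.
print(outputs.action_preds.shape)
```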
.venv/lib/python3.11/site-packages/transformers/models/focalnet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (772 Bytes).
 
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_gpt_sw3 import *
22
+ else:
23
+ import sys
24
+
25
+ _file = globals()["__file__"]
26
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
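
As elsewhere in the library, this `__init__.py` registers a `_LazyModule` so the actual tokenizer code (and its `sentencepiece` dependency) is only imported when an attribute is first accessed. A small, hedged illustration of what that means for user code, using only the public API:

```python
# Importing the package does not yet load tokenization_gpt_sw3 (or sentencepiece);
# the submodule is resolved lazily on first attribute access.
from transformers.models import gpt_sw3

tokenizer_cls = gpt_sw3.GPTSw3Tokenizer  # this access triggers the real import
print(tokenizer_cls.__name__)            # -> "GPTSw3Tokenizer"
```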
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (730 Bytes).
 
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/__pycache__/tokenization_gpt_sw3.cpython-311.pyc ADDED
Binary file (16 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/gpt_sw3/tokenization_gpt_sw3.py ADDED
@@ -0,0 +1,299 @@
1
+ """The tokenizer used by the GPT-SW3 models."""
2
+
3
+ import os
4
+ import re
5
+ import unicodedata
6
+ from shutil import copyfile
7
+ from typing import Any, Dict, List, Optional, Tuple, Union
8
+
9
+ import sentencepiece as spm
10
+
11
+ from ...tokenization_utils import PreTrainedTokenizer
12
+ from ...utils import is_torch_available, logging
13
+
14
+
15
+ if is_torch_available():
16
+ import torch
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
21
+
22
+
23
+ class GPTSw3Tokenizer(PreTrainedTokenizer):
24
+ """
25
+ Construct a GPTSw3 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
26
+
27
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
28
+ this superclass for more information regarding those methods.
29
+
30
+ Example usage:
31
+ ```python
32
+ >>> from transformers import GPTSw3Tokenizer
33
+
34
+ >>> tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")
35
+ >>> tokenizer("Svenska är kul!")["input_ids"]
36
+ [1814, 377, 3617, 63504]
37
+ ```
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
42
+ contains the vocabulary necessary to instantiate a tokenizer.
43
+ do_lower_case (`bool`, *optional*, defaults to `False`):
44
+ Whether or not to lowercase the input when tokenizing.
45
+ remove_space (`bool`, *optional*, defaults to `False`):
46
+ Whether or not to strip the text when tokenizing (removing excess spaces before and after the string).
47
+ keep_accents (`bool`, *optional*, defaults to `False`):
48
+ Whether or not to keep accents when tokenizing.
49
+ pad_token (`str`, *optional*):
50
+ The token used for padding, for example when batching sequences of different lengths. If not provided, will
51
+ default to '<pad>' or '<unk>' depending on model size.
52
+ unk_token (`str`, *optional*):
53
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
54
+ token instead. If not provided, will default to '<unk>'.
55
+ eos_token (`str`, *optional*):
56
+ The end of sequence token seen during pretraining. If not provided, will default to '<|endoftext|>'
57
+ bos_token (`str`, *optional*):
58
+ The beginning of sequence token that can be used for downstream task, was not seen during pretraining. If
59
+ not provided, will default to '<s>' or '<|endoftext|>', depending on model size.
60
+ sp_model_kwargs (`dict`, *optional*):
61
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
62
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
63
+ to set:
64
+
65
+ - `enable_sampling`: Enable subword regularization.
66
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
67
+
68
+ - `nbest_size = {0,1}`: No sampling is performed.
69
+ - `nbest_size > 1`: samples from the nbest_size results.
70
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
71
+ using forward-filtering-and-backward-sampling algorithm.
72
+
73
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
74
+ BPE-dropout.
75
+
76
+ Attributes:
77
+ sp_model (`SentencePieceProcessor`):
78
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
79
+ whitespaces (`set`):
80
+ The whitespaces that are replaced in the whitespace normalization in preprocessing.
81
+ non_printing_characters_re (`Pattern`):
82
+ The compiled regular expression to remove non-printing characters in preprocessing.
83
+ """
84
+
85
+ vocab_files_names = VOCAB_FILES_NAMES
86
+ model_input_names = ["input_ids", "attention_mask"]
87
+
88
+ def __init__(
89
+ self,
90
+ vocab_file,
91
+ do_lower_case=False,
92
+ remove_space=False,
93
+ keep_accents=False,
94
+ pad_token=None,
95
+ unk_token=None,
96
+ eos_token=None,
97
+ bos_token=None,
98
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
99
+ **kwargs,
100
+ ) -> None:
101
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
102
+
103
+ name_or_path = kwargs.get("name_or_path")
104
+ if name_or_path is None:
105
+ logger.warning(
106
+ "name_or_path not provided, will work for all GPTSw3 models except gpt-sw3-7b,"
107
+ " you are testing the model, this can safely be ignored"
108
+ )
109
+ name_or_path = "None"
110
+
111
+ # Default definitions for our 2 tokenizer versions, with None-checks to enable proper testing
112
+ eos_token = "<|endoftext|>" if eos_token is None else eos_token
113
+ unk_token = "<unk>" if unk_token is None else unk_token
114
+ if "gpt-sw3-7b" in name_or_path:
115
+ pad_token = unk_token if pad_token is None else pad_token
116
+ bos_token = eos_token if bos_token is None else bos_token
117
+ else:
118
+ pad_token = "<pad>" if pad_token is None else pad_token
119
+ bos_token = "<s>" if bos_token is None else bos_token
120
+
121
+ self.do_lower_case = do_lower_case
122
+ self.remove_space = remove_space
123
+ self.keep_accents = keep_accents
124
+ self.vocab_file = vocab_file
125
+
126
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
127
+ self.sp_model.Load(vocab_file)
128
+
129
+ # Used for whitespace normalization in input texts
130
+ # fmt: off
131
+ self.whitespaces = {" ", " ", " ", " ", " ", " ", " ", " ", " ", " ", "", "„"}
132
+ # fmt: on
133
+
134
+ # Regular expression to remove non-printing characters (e.g. some unicode control chars) in preprocessing
135
+ self.non_printing_characters_re = re.compile(
136
+ f"[{''.join(map(chr, list(range(0, 9)) + list(range(11, 32)) + list(range(127, 160)) + [160, 173, 8203]))}]"
137
+ )
138
+
139
+ super().__init__(
140
+ do_lower_case=do_lower_case,
141
+ remove_space=remove_space,
142
+ keep_accents=keep_accents,
143
+ bos_token=bos_token,
144
+ eos_token=eos_token,
145
+ unk_token=unk_token,
146
+ pad_token=pad_token,
147
+ sp_model_kwargs=self.sp_model_kwargs,
148
+ **kwargs,
149
+ )
150
+
151
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__getstate__
152
+ def __getstate__(self):
153
+ state = self.__dict__.copy()
154
+ state["sp_model"] = None
155
+ return state
156
+
157
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.__setstate__
158
+ def __setstate__(self, d):
159
+ self.__dict__ = d
160
+
161
+ # for backward compatibility
162
+ if not hasattr(self, "sp_model_kwargs"):
163
+ self.sp_model_kwargs = {}
164
+
165
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
166
+ self.sp_model.Load(self.vocab_file)
167
+
168
+ @property
169
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.vocab_size
170
+ def vocab_size(self) -> int:
171
+ return len(self.sp_model)
172
+
173
+ def preprocess_text(self, text: str) -> str:
174
+ """
175
+ Returns the preprocessed text. This procedure is identical to what was used when training the tokenizer.
176
+ """
177
+
178
+ # Remove non-printing characters
179
+ text = self.non_printing_characters_re.sub("", text)
180
+
181
+ # Normalize whitespaces
182
+ text = "".join([char if char not in self.whitespaces else " " for char in text])
183
+
184
+ # NFC Unicode normalization
185
+ text = unicodedata.normalize("NFC", text)
186
+ return text
187
+
188
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
189
+ text = self.preprocess_text(text)
190
+ return self.sp_model.encode(text, out_type=str)
191
+
192
+ def _convert_token_to_id(self, token: str) -> int:
193
+ """Converts a token (str) to an id (int) using the vocab."""
194
+ return self.sp_model.PieceToId(token)
195
+
196
+ def _convert_id_to_token(self, index: int) -> str:
197
+ """Converts an index (int) to a token (str) using the vocab."""
198
+ return self.sp_model.IdToPiece(index)
199
+
200
+ @staticmethod
201
+ def clean_up_tokenization(out_string: str) -> str:
202
+ """Returns the input string, this function is overridden to remove the default clean up."""
203
+ return out_string
204
+
205
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
206
+ """Converts a sequence of tokens (strings) to a single string. Special tokens remain intact."""
207
+ current_sub_tokens = []
208
+ out_string = ""
209
+ prev_is_special = False
210
+ for token in tokens:
211
+ # make sure that special tokens are not decoded using sentencepiece model
212
+ if token in self.all_special_tokens:
213
+ # TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
214
+ if not prev_is_special:
215
+ out_string += " "
216
+
217
+ out_string += self.sp_model.decode(current_sub_tokens) + token
218
+ prev_is_special = True
219
+ current_sub_tokens = []
220
+ else:
221
+ current_sub_tokens.append(token)
222
+ prev_is_special = False
223
+ out_string += self.sp_model.decode(current_sub_tokens)
224
+
225
+ return out_string
226
+
227
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.get_vocab
228
+ def get_vocab(self) -> Dict[str, int]:
229
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
230
+ vocab.update(self.added_tokens_encoder)
231
+ return vocab
232
+
233
+ # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.save_vocabulary
234
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
235
+ if not os.path.isdir(save_directory):
236
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
237
+ return
238
+ out_vocab_file = os.path.join(
239
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
240
+ )
241
+
242
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
243
+ copyfile(self.vocab_file, out_vocab_file)
244
+ elif not os.path.isfile(self.vocab_file):
245
+ with open(out_vocab_file, "wb") as fi:
246
+ content_spiece_model = self.sp_model.serialized_model_proto()
247
+ fi.write(content_spiece_model)
248
+
249
+ return (out_vocab_file,)
250
+
251
+ def encode_fast(
252
+ self, text: Union[str, List[str]], return_tensors: Union[str, bool] = False
253
+ ) -> Union[List[int], List[List[int]], "torch.Tensor"]:
254
+ """
255
+ Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
256
+ functionality but is often much faster.
257
+
258
+ Does NOT handle special tokens correctly; these can manually be added as ids afterwards.
259
+
260
+ Does NOT support padding; padding ids can manually be added afterwards.
261
+
262
+ Use default HuggingFace tokenization methods for full functionality.
263
+
264
+ Args:
265
+ text (`str` or `List[str]`): One or several text(s) to convert to token ids.
266
+ return_tensors (`str` or `bool`): Returns PyTorch tensors if set to True or "pt"
267
+
268
+ Returns:
269
+ `List[int]`, `List[List[int]]`, or `torch.Tensor`: The encoded text(s) as token ids.
270
+ """
271
+
272
+ if isinstance(text, str):
273
+ text = self.preprocess_text(text)
274
+ token_ids = self.sp_model.encode(text)
275
+ else:
276
+ text = [self.preprocess_text(t) for t in text]
277
+ token_ids = self.sp_model.encode(text)
278
+
279
+ if return_tensors is True or return_tensors == "pt":
280
+ token_ids = torch.tensor(token_ids)
281
+
282
+ return token_ids
283
+
284
+ def decode_fast(self, token_ids: Union[int, List[int]]) -> str:
285
+ """
286
+ Encodes a text or batch of texts to token ids using preprocessing and the raw SP tokenizer. This has reduced
287
+ functionality but is often much faster.
288
+
289
+ Args:
290
+ token_ids (`int` or `List[int]`): Encoded token or text as token id(s).
291
+
292
+ Returns:
293
+ `str`: Decoded text
294
+ """
295
+
296
+ return self.sp_model.decode(token_ids)
297
+
298
+
299
+ __all__ = ["GPTSw3Tokenizer"]
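
The `encode_fast`/`decode_fast` pair above trades special-token and padding handling for speed by calling the SentencePiece model directly after `preprocess_text`. A short sketch of the round trip, reusing the checkpoint name from the class docstring (the exact ids depend on the actual SentencePiece model, so none are shown here):

```python
from transformers import GPTSw3Tokenizer

tokenizer = GPTSw3Tokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m")

# Full-featured path: preprocessing plus special-token handling.
ids = tokenizer("Svenska är kul!")["input_ids"]

# Reduced fast path: raw SentencePiece encoding, no special tokens, no padding.
fast_ids = tokenizer.encode_fast("Svenska är kul!")
text = tokenizer.decode_fast(fast_ids)
```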
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_musicgen import *
22
+ from .modeling_musicgen import *
23
+ from .processing_musicgen import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (813 Bytes).
 
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/configuration_musicgen.cpython-311.pyc ADDED
Binary file (11.1 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/musicgen/__pycache__/processing_musicgen.cpython-311.pyc ADDED
Binary file (6.39 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/musicgen/configuration_musicgen.py ADDED
@@ -0,0 +1,247 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MusicGen model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ..auto.configuration_auto import AutoConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class MusicgenDecoderConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`MusicgenDecoder`]. It is used to instantiate a
28
+ MusicGen decoder according to the specified arguments, defining the model architecture. Instantiating a
29
+ configuration with the defaults will yield a similar configuration to that of the MusicGen
30
+ [facebook/musicgen-small](https://huggingface.co/facebook/musicgen-small) architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 2048):
38
+ Vocabulary size of the MusicgenDecoder model. Defines the number of different tokens that can be
39
+ represented by the `input_ids` passed when calling [`MusicgenDecoder`].
40
+ hidden_size (`int`, *optional*, defaults to 1024):
41
+ Dimensionality of the layers and the pooler layer.
42
+ num_hidden_layers (`int`, *optional*, defaults to 24):
43
+ Number of decoder layers.
44
+ num_attention_heads (`int`, *optional*, defaults to 16):
45
+ Number of attention heads for each attention layer in the Transformer block.
46
+ ffn_dim (`int`, *optional*, defaults to 4096):
47
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
48
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
49
+ The non-linear activation function (function or string) in the decoder and pooler. If string, `"gelu"`,
50
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
51
+ dropout (`float`, *optional*, defaults to 0.1):
52
+ The dropout probability for all fully connected layers in the embeddings, text_encoder, and pooler.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ activation_dropout (`float`, *optional*, defaults to 0.0):
56
+ The dropout ratio for activations inside the fully connected layer.
57
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
58
+ The maximum sequence length that this model might ever be used with. Typically, set this to something large
59
+ just in case (e.g., 512 or 1024 or 2048).
60
+ initializer_factor (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ layerdrop (`float`, *optional*, defaults to 0.0):
63
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
64
+ for more details.
65
+ scale_embedding (`bool`, *optional*, defaults to `False`):
66
+ Scale embeddings by dividing by sqrt(hidden_size).
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether the model should return the last key/values attentions (not used by all models).
69
+ num_codebooks (`int`, *optional*, defaults to 4):
70
+ The number of parallel codebooks forwarded to the model.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether input and output word embeddings should be tied.
73
+ audio_channels (`int`, *optional*, defaults to 1):
74
+ Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
75
+ audio stream for the left/right output channels. Mono models generate a single audio stream output.
76
+ """
77
+
78
+ model_type = "musicgen_decoder"
79
+ base_config_key = "decoder_config"
80
+ keys_to_ignore_at_inference = ["past_key_values"]
81
+
82
+ def __init__(
83
+ self,
84
+ vocab_size=2048,
85
+ max_position_embeddings=2048,
86
+ num_hidden_layers=24,
87
+ ffn_dim=4096,
88
+ num_attention_heads=16,
89
+ layerdrop=0.0,
90
+ use_cache=True,
91
+ activation_function="gelu",
92
+ hidden_size=1024,
93
+ dropout=0.1,
94
+ attention_dropout=0.0,
95
+ activation_dropout=0.0,
96
+ initializer_factor=0.02,
97
+ scale_embedding=False,
98
+ num_codebooks=4,
99
+ audio_channels=1,
100
+ pad_token_id=2048,
101
+ bos_token_id=2048,
102
+ eos_token_id=None,
103
+ tie_word_embeddings=False,
104
+ **kwargs,
105
+ ):
106
+ self.vocab_size = vocab_size
107
+ self.max_position_embeddings = max_position_embeddings
108
+ self.hidden_size = hidden_size
109
+ self.ffn_dim = ffn_dim
110
+ self.num_hidden_layers = num_hidden_layers
111
+ self.num_attention_heads = num_attention_heads
112
+ self.dropout = dropout
113
+ self.attention_dropout = attention_dropout
114
+ self.activation_dropout = activation_dropout
115
+ self.activation_function = activation_function
116
+ self.initializer_factor = initializer_factor
117
+ self.layerdrop = layerdrop
118
+ self.use_cache = use_cache
119
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
120
+ self.num_codebooks = num_codebooks
121
+
122
+ if audio_channels not in [1, 2]:
123
+ raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")
124
+ self.audio_channels = audio_channels
125
+
126
+ super().__init__(
127
+ pad_token_id=pad_token_id,
128
+ bos_token_id=bos_token_id,
129
+ eos_token_id=eos_token_id,
130
+ tie_word_embeddings=tie_word_embeddings,
131
+ **kwargs,
132
+ )
133
+
134
+
135
+ class MusicgenConfig(PretrainedConfig):
136
+ r"""
137
+ This is the configuration class to store the configuration of a [`MusicgenModel`]. It is used to instantiate a
138
+ MusicGen model according to the specified arguments, defining the text encoder, audio encoder and MusicGen decoder
139
+ configs.
140
+
141
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
142
+ documentation from [`PretrainedConfig`] for more information.
143
+
144
+ Args:
145
+ kwargs (*optional*):
146
+ Dictionary of keyword arguments. Notably:
147
+
148
+ - **text_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
149
+ defines the text encoder config.
150
+ - **audio_encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
151
+ defines the audio encoder config.
152
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
153
+ the decoder config.
154
+
155
+ Example:
156
+
157
+ ```python
158
+ >>> from transformers import (
159
+ ... MusicgenConfig,
160
+ ... MusicgenDecoderConfig,
161
+ ... T5Config,
162
+ ... EncodecConfig,
163
+ ... MusicgenForConditionalGeneration,
164
+ ... )
165
+
166
+ >>> # Initializing text encoder, audio encoder, and decoder model configurations
167
+ >>> text_encoder_config = T5Config()
168
+ >>> audio_encoder_config = EncodecConfig()
169
+ >>> decoder_config = MusicgenDecoderConfig()
170
+
171
+ >>> configuration = MusicgenConfig.from_sub_models_config(
172
+ ... text_encoder_config, audio_encoder_config, decoder_config
173
+ ... )
174
+
175
+ >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
176
+ >>> model = MusicgenForConditionalGeneration(configuration)
177
+
178
+ >>> # Accessing the model configuration
179
+ >>> configuration = model.config
180
+ >>> config_text_encoder = model.config.text_encoder
181
+ >>> config_audio_encoder = model.config.audio_encoder
182
+ >>> config_decoder = model.config.decoder
183
+
184
+ >>> # Saving the model, including its configuration
185
+ >>> model.save_pretrained("musicgen-model")
186
+
187
+ >>> # loading model and config from pretrained folder
188
+ >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
189
+ >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
190
+ ```"""
191
+
192
+ model_type = "musicgen"
193
+ sub_configs = {
194
+ "text_encoder": AutoConfig,
195
+ "audio_encoder": AutoConfig,
196
+ "decoder": MusicgenDecoderConfig,
197
+ }
198
+ is_composition = True
199
+
200
+ def __init__(self, **kwargs):
201
+ super().__init__(**kwargs)
202
+ if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
203
+ raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
204
+
205
+ text_encoder_config = kwargs.pop("text_encoder")
206
+ text_encoder_model_type = text_encoder_config.pop("model_type")
207
+
208
+ audio_encoder_config = kwargs.pop("audio_encoder")
209
+ audio_encoder_model_type = audio_encoder_config.pop("model_type")
210
+
211
+ decoder_config = kwargs.pop("decoder")
212
+
213
+ self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
214
+ self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
215
+ self.decoder = MusicgenDecoderConfig(**decoder_config)
216
+ self.is_encoder_decoder = True
217
+
218
+ @classmethod
219
+ def from_sub_models_config(
220
+ cls,
221
+ text_encoder_config: PretrainedConfig,
222
+ audio_encoder_config: PretrainedConfig,
223
+ decoder_config: MusicgenDecoderConfig,
224
+ **kwargs,
225
+ ):
226
+ r"""
227
+ Instantiate a [`MusicgenConfig`] (or a derived class) from text encoder, audio encoder and decoder
228
+ configurations.
229
+
230
+ Returns:
231
+ [`MusicgenConfig`]: An instance of a configuration object
232
+ """
233
+
234
+ return cls(
235
+ text_encoder=text_encoder_config.to_dict(),
236
+ audio_encoder=audio_encoder_config.to_dict(),
237
+ decoder=decoder_config.to_dict(),
238
+ **kwargs,
239
+ )
240
+
241
+ @property
242
+ # This is a property because you might want to change the codec model on the fly
243
+ def sampling_rate(self):
244
+ return self.audio_encoder.sampling_rate
245
+
246
+
247
+ __all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]
.venv/lib/python3.11/site-packages/transformers/models/musicgen/modeling_musicgen.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.11/site-packages/transformers/models/musicgen/processing_musicgen.py ADDED
@@ -0,0 +1,144 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Text/audio processor class for MusicGen
17
+ """
18
+
19
+ from typing import List, Optional
20
+
21
+ import numpy as np
22
+
23
+ from ...processing_utils import ProcessorMixin
24
+ from ...utils import to_numpy
25
+
26
+
27
+ class MusicgenProcessor(ProcessorMixin):
28
+ r"""
29
+ Constructs a MusicGen processor which wraps an EnCodec feature extractor and a T5 tokenizer into a single processor
30
+ class.
31
+
32
+ [`MusicgenProcessor`] offers all the functionalities of [`EncodecFeatureExtractor`] and [`T5Tokenizer`]. See
33
+ [`~MusicgenProcessor.__call__`] and [`~MusicgenProcessor.decode`] for more information.
34
+
35
+ Args:
36
+ feature_extractor (`EncodecFeatureExtractor`):
37
+ An instance of [`EncodecFeatureExtractor`]. The feature extractor is a required input.
38
+ tokenizer (`T5Tokenizer`):
39
+ An instance of [`T5Tokenizer`]. The tokenizer is a required input.
40
+ """
41
+
42
+ feature_extractor_class = "EncodecFeatureExtractor"
43
+ tokenizer_class = ("T5Tokenizer", "T5TokenizerFast")
44
+
45
+ def __init__(self, feature_extractor, tokenizer):
46
+ super().__init__(feature_extractor, tokenizer)
47
+ self.current_processor = self.feature_extractor
48
+ self._in_target_context_manager = False
49
+
50
+ def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
51
+ return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
52
+
53
+ def __call__(self, *args, **kwargs):
54
+ """
55
+ Forwards the `audio` argument to EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`] and the `text`
56
+ argument to [`~T5Tokenizer.__call__`]. Please refer to the docstring of the above two methods for more
57
+ information.
58
+ """
59
+ # For backward compatibility
60
+ if self._in_target_context_manager:
61
+ return self.current_processor(*args, **kwargs)
62
+
63
+ audio = kwargs.pop("audio", None)
64
+ sampling_rate = kwargs.pop("sampling_rate", None)
65
+ text = kwargs.pop("text", None)
66
+ if len(args) > 0:
67
+ audio = args[0]
68
+ args = args[1:]
69
+
70
+ if audio is None and text is None:
71
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
72
+
73
+ if text is not None:
74
+ inputs = self.tokenizer(text, **kwargs)
75
+
76
+ if audio is not None:
77
+ audio_inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
78
+
79
+ if audio is None:
80
+ return inputs
81
+
82
+ elif text is None:
83
+ return audio_inputs
84
+
85
+ else:
86
+ inputs["input_values"] = audio_inputs["input_values"]
87
+ if "padding_mask" in audio_inputs:
88
+ inputs["padding_mask"] = audio_inputs["padding_mask"]
89
+ return inputs
90
+
91
+ def batch_decode(self, *args, **kwargs):
92
+ """
93
+ This method is used to decode either batches of audio outputs from the MusicGen model, or batches of token ids
94
+ from the tokenizer. In the case of decoding token ids, this method forwards all its arguments to T5Tokenizer's
95
+ [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information.
96
+ """
97
+ audio_values = kwargs.pop("audio", None)
98
+ padding_mask = kwargs.pop("padding_mask", None)
99
+
100
+ if len(args) > 0:
101
+ audio_values = args[0]
102
+ args = args[1:]
103
+
104
+ if audio_values is not None:
105
+ return self._decode_audio(audio_values, padding_mask=padding_mask)
106
+ else:
107
+ return self.tokenizer.batch_decode(*args, **kwargs)
108
+
109
+ def decode(self, *args, **kwargs):
110
+ """
111
+ This method forwards all its arguments to T5Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
112
+ docstring of this method for more information.
113
+ """
114
+ return self.tokenizer.decode(*args, **kwargs)
115
+
116
+ def _decode_audio(self, audio_values, padding_mask: Optional = None) -> List[np.ndarray]:
117
+ """
118
+ This method strips any padding from the audio values to return a list of numpy audio arrays.
119
+ """
120
+ audio_values = to_numpy(audio_values)
121
+ bsz, channels, seq_len = audio_values.shape
122
+
123
+ if padding_mask is None:
124
+ return list(audio_values)
125
+
126
+ padding_mask = to_numpy(padding_mask)
127
+
128
+ # match the sequence length of the padding mask to the generated audio arrays by padding with the **non-padding**
129
+ # token (so that the generated audio values are **not** treated as padded tokens)
130
+ difference = seq_len - padding_mask.shape[-1]
131
+ padding_value = 1 - self.feature_extractor.padding_value
132
+ padding_mask = np.pad(padding_mask, ((0, 0), (0, difference)), "constant", constant_values=padding_value)
133
+
134
+ audio_values = audio_values.tolist()
135
+ for i in range(bsz):
136
+ sliced_audio = np.asarray(audio_values[i])[
137
+ padding_mask[i][None, :] != self.feature_extractor.padding_value
138
+ ]
139
+ audio_values[i] = sliced_audio.reshape(channels, -1)
140
+
141
+ return audio_values
142
+
143
+
144
+ __all__ = ["MusicgenProcessor"]
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_olmoe import *
22
+ from .modeling_olmoe import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
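As an aside (not part of the diff), the `_LazyModule` pattern used above keeps the package import cheap and defers loading the heavy modeling code until an attribute is first accessed. A minimal sketch, assuming a recent `transformers` installation that ships OLMoE:

```python
import importlib

# Importing the package is cheap: the lazy module only records what it can provide.
olmoe = importlib.import_module("transformers.models.olmoe")

# The first attribute access triggers the real import of configuration_olmoe.
config_cls = olmoe.OlmoeConfig
print(config_cls().model_type)  # "olmoe"
```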
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (763 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/configuration_olmoe.cpython-311.pyc ADDED
Binary file (8.47 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmoe/__pycache__/modeling_olmoe.cpython-311.pyc ADDED
Binary file (66 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmoe/configuration_olmoe.py ADDED
@@ -0,0 +1,182 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ """OLMoE model configuration"""
13
+
14
+ from ...configuration_utils import PretrainedConfig
15
+ from ...modeling_rope_utils import rope_config_validation
16
+
17
+
18
+ class OlmoeConfig(PretrainedConfig):
19
+ r"""
20
+ This is the configuration class to store the configuration of an [`OlmoeModel`]. It is used to instantiate an OLMoE
21
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
22
+ defaults will yield a similar configuration to that of the [allenai/OLMoE-1B-7B-0924](https://huggingface.co/allenai/OLMoE-1B-7B-0924).
23
+
24
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
25
+ documentation from [`PretrainedConfig`] for more information.
26
+
27
+
28
+ Args:
29
+ vocab_size (`int`, *optional*, defaults to 50304):
30
+ Vocabulary size of the OLMoE model. Defines the number of different tokens that can be represented by the
31
+ `input_ids` passed when calling [`OlmoeModel`].
32
+ hidden_size (`int`, *optional*, defaults to 2048):
33
+ Dimension of the hidden representations.
34
+ intermediate_size (`int`, *optional*, defaults to 2048):
35
+ Dimension of the MLP representations.
36
+ num_hidden_layers (`int`, *optional*, defaults to 16):
37
+ Number of hidden layers in the Transformer decoder.
38
+ num_attention_heads (`int`, *optional*, defaults to 16):
39
+ Number of attention heads for each attention layer in the Transformer decoder.
40
+ num_key_value_heads (`int`, *optional*):
41
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
42
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
43
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
44
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
45
+ by meanpooling all the original heads within that group. For more details checkout [this
46
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
47
+ `num_attention_heads`.
48
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
49
+ The non-linear activation function (function or string) in the decoder.
50
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
51
+ The maximum sequence length that this model might ever be used with.
52
+ initializer_range (`float`, *optional*, defaults to 0.02):
53
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
54
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
55
+ The epsilon used by the rms normalization layers.
56
+ use_cache (`bool`, *optional*, defaults to `True`):
57
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
58
+ relevant if `config.is_decoder=True`.
59
+ pad_token_id (`int`, *optional*, defaults to 1):
60
+ Padding token id.
61
+ bos_token_id (`int`, *optional*):
62
+ Beginning of stream token id.
63
+ eos_token_id (`int`, *optional*, defaults to 50279):
64
+ End of stream token id.
65
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
66
+ Whether to tie weight embeddings
67
+ rope_theta (`float`, *optional*, defaults to 10000.0):
68
+ The base period of the RoPE embeddings.
69
+ rope_scaling (`Dict`, *optional*):
70
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
71
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
72
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
73
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
74
+ these scaling strategies behave:
75
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
76
+ experimental feature, subject to breaking API changes in future versions.
77
+ attention_bias (`bool`, *optional*, defaults to `False`):
78
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
79
+ attention_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout ratio for the attention probabilities.
81
+ clip_qkv (`float`, *optional*):
82
+ If not `None`, elements of query, key and value attention states are clipped so that their
83
+ absolute value does not exceed this value.
84
+ num_experts_per_tok (`int`, *optional*, defaults to 8):
85
+ Number of selected experts.
86
+ num_experts (`int`, *optional*, defaults to 64):
87
+ Number of routed experts.
88
+ output_router_logits (`bool`, *optional*, defaults to `False`):
89
+ Whether or not the router logits should be returned by the model. Enabling this will also
90
+ allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
91
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.01):
92
+ The aux loss factor for the total loss.
93
+ norm_topk_prob (`bool`, *optional*, defaults to `False`):
94
+ Whether to normalize the topk probabilities.
95
+
96
+ ```python
97
+ >>> from transformers import OlmoeModel, OlmoeConfig
98
+
99
+ >>> # Initializing an OLMoE 7B A1B style configuration
100
+ >>> configuration = OlmoeConfig()
101
+
102
+ >>> # Initializing a model from the OLMoE 7B A1B style configuration
103
+ >>> model = OlmoeModel(configuration)
104
+
105
+ >>> # Accessing the model configuration
106
+ >>> configuration = model.config
107
+ ```"""
108
+
109
+ model_type = "olmoe"
110
+ keys_to_ignore_at_inference = ["past_key_values"]
111
+
112
+ def __init__(
113
+ self,
114
+ vocab_size=50304,
115
+ hidden_size=2048,
116
+ intermediate_size=2048,
117
+ num_hidden_layers=16,
118
+ num_attention_heads=16,
119
+ num_key_value_heads=None,
120
+ hidden_act="silu",
121
+ max_position_embeddings=4096,
122
+ initializer_range=0.02,
123
+ rms_norm_eps=1e-05,
124
+ use_cache=True,
125
+ pad_token_id=1,
126
+ bos_token_id=None,
127
+ eos_token_id=50279,
128
+ tie_word_embeddings=False,
129
+ rope_theta=10000.0,
130
+ rope_scaling=None,
131
+ attention_bias=False,
132
+ attention_dropout=0.0,
133
+ clip_qkv=None,
134
+ num_experts_per_tok=8,
135
+ num_experts=64,
136
+ output_router_logits=False,
137
+ router_aux_loss_coef=0.01,
138
+ norm_topk_prob=False,
139
+ **kwargs,
140
+ ):
141
+ self.vocab_size = vocab_size
142
+ self.max_position_embeddings = max_position_embeddings
143
+ self.hidden_size = hidden_size
144
+ self.intermediate_size = intermediate_size
145
+ self.num_hidden_layers = num_hidden_layers
146
+ self.num_attention_heads = num_attention_heads
147
+
148
+ # for backward compatibility
149
+ if num_key_value_heads is None:
150
+ num_key_value_heads = num_attention_heads
151
+
152
+ self.num_key_value_heads = num_key_value_heads
153
+ self.hidden_act = hidden_act
154
+ self.initializer_range = initializer_range
155
+ self.rms_norm_eps = rms_norm_eps
156
+ self.use_cache = use_cache
157
+ self.rope_theta = rope_theta
158
+ self.rope_scaling = rope_scaling
159
+ self.attention_bias = attention_bias
160
+ self.attention_dropout = attention_dropout
161
+ self.clip_qkv = clip_qkv
162
+ self.num_experts_per_tok = num_experts_per_tok
163
+ self.num_experts = num_experts
164
+ self.output_router_logits = output_router_logits
165
+ self.router_aux_loss_coef = router_aux_loss_coef
166
+ self.norm_topk_prob = norm_topk_prob
167
+ # Validate the correctness of rotary position embeddings parameters
168
+ # BC: if there is a 'type' field, move it to 'rope_type'.
169
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
170
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
171
+ rope_config_validation(self)
172
+
173
+ super().__init__(
174
+ pad_token_id=pad_token_id,
175
+ bos_token_id=bos_token_id,
176
+ eos_token_id=eos_token_id,
177
+ tie_word_embeddings=tie_word_embeddings,
178
+ **kwargs,
179
+ )
180
+
181
+
182
+ __all__ = ["OlmoeConfig"]
.venv/lib/python3.11/site-packages/transformers/models/olmoe/modeling_olmoe.py ADDED
@@ -0,0 +1,1299 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+ """PyTorch OLMoE model."""
13
+
14
+ import math
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+
22
+ from ...activations import ACT2FN
23
+ from ...cache_utils import Cache, DynamicCache, StaticCache
24
+ from ...generation import GenerationMixin
25
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
26
+ from ...modeling_outputs import (
27
+ MoeCausalLMOutputWithPast,
28
+ MoeModelOutputWithPast,
29
+ )
30
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS
33
+ from ...utils import (
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ is_flash_attn_2_available,
37
+ is_flash_attn_greater_or_equal_2_10,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from .configuration_olmoe import OlmoeConfig
42
+
43
+
44
+ if is_flash_attn_2_available():
45
+ from ...modeling_flash_attention_utils import _flash_attention_forward
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CONFIG_FOR_DOC = "OlmoeConfig"
51
+
52
+
53
+ # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
54
+ def load_balancing_loss_func(
55
+ gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
56
+ num_experts: Optional[int] = None,
57
+ top_k=2,
58
+ attention_mask: Optional[torch.Tensor] = None,
59
+ ) -> Union[torch.Tensor, int]:
60
+ r"""
61
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.
62
+
63
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
64
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
65
+ experts is too unbalanced.
66
+
67
+ Args:
68
+ gate_logits:
69
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
70
+ shape [batch_size X sequence_length, num_experts].
71
+ num_experts:
72
+ Number of experts
73
+ top_k:
74
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
75
+ parameter.
76
+ attention_mask (`torch.Tensor`, *optional*):
77
+ The attention_mask used in forward function
78
+ shape [batch_size X sequence_length] if not None.
79
+
80
+ Returns:
81
+ The auxiliary loss.
82
+ """
83
+ if gate_logits is None or not isinstance(gate_logits, tuple):
84
+ return 0
85
+
86
+ if isinstance(gate_logits, tuple):
87
+ compute_device = gate_logits[0].device
88
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
89
+
90
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
91
+
92
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
93
+
94
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
95
+
96
+ if attention_mask is None:
97
+ # Compute the percentage of tokens routed to each expert
98
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
99
+
100
+ # Compute the average probability of routing to these experts
101
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
102
+ else:
103
+ batch_size, sequence_length = attention_mask.shape
104
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
105
+
106
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
107
+ expert_attention_mask = (
108
+ attention_mask[None, :, :, None, None]
109
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
110
+ .reshape(-1, top_k, num_experts)
111
+ .to(compute_device)
112
+ )
113
+
114
+ # Compute the percentage of tokens routed to each expert
115
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
116
+ expert_attention_mask, dim=0
117
+ )
118
+
119
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
120
+ router_per_expert_attention_mask = (
121
+ attention_mask[None, :, :, None]
122
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
123
+ .reshape(-1, num_experts)
124
+ .to(compute_device)
125
+ )
126
+
127
+ # Compute the average probability of routing to these experts
128
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
129
+ router_per_expert_attention_mask, dim=0
130
+ )
131
+
132
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
133
+ return overall_loss * num_experts
134
+
135
+
136
+ class OlmoeRMSNorm(nn.Module):
137
+ def __init__(self, hidden_size, eps=1e-5):
138
+ """
139
+ OlmoeRMSNorm is equivalent to T5LayerNorm
140
+ """
141
+ super().__init__()
142
+ self.weight = nn.Parameter(torch.ones(hidden_size))
143
+ self.variance_epsilon = eps
144
+
145
+ def forward(self, hidden_states):
146
+ input_dtype = hidden_states.dtype
147
+ hidden_states = hidden_states.to(torch.float32)
148
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
149
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
150
+ return self.weight * hidden_states.to(input_dtype)
151
+
152
+ def extra_repr(self):
153
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
154
+
155
+
156
+ ALL_LAYERNORM_LAYERS.append(OlmoeRMSNorm)
157
+
158
+
159
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmoe
160
+ class OlmoeRotaryEmbedding(nn.Module):
161
+ def __init__(self, config: OlmoeConfig, device=None):
162
+ super().__init__()
163
+ # BC: "rope_type" was originally "type"
164
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
165
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
166
+ else:
167
+ self.rope_type = "default"
168
+ self.max_seq_len_cached = config.max_position_embeddings
169
+ self.original_max_seq_len = config.max_position_embeddings
170
+
171
+ self.config = config
172
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
173
+
174
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
175
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
176
+ self.original_inv_freq = self.inv_freq
177
+
178
+ def _dynamic_frequency_update(self, position_ids, device):
179
+ """
180
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
181
+ 1 - growing beyond the cached sequence length (allow scaling)
182
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
183
+ """
184
+ seq_len = torch.max(position_ids) + 1
185
+ if seq_len > self.max_seq_len_cached: # growth
186
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
187
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
188
+ self.max_seq_len_cached = seq_len
189
+
190
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
191
+ # This .to() is needed if the model has been moved to a device after being initialized (because
192
+ # the buffer is automatically moved, but not the original copy)
193
+ self.original_inv_freq = self.original_inv_freq.to(device)
194
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
195
+ self.max_seq_len_cached = self.original_max_seq_len
196
+
197
+ @torch.no_grad()
198
+ def forward(self, x, position_ids):
199
+ if "dynamic" in self.rope_type:
200
+ self._dynamic_frequency_update(position_ids, device=x.device)
201
+
202
+ # Core RoPE block
203
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
204
+ position_ids_expanded = position_ids[:, None, :].float()
205
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
206
+ device_type = x.device.type
207
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
208
+ with torch.autocast(device_type=device_type, enabled=False):
209
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
210
+ emb = torch.cat((freqs, freqs), dim=-1)
211
+ cos = emb.cos()
212
+ sin = emb.sin()
213
+
214
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
215
+ cos = cos * self.attention_scaling
216
+ sin = sin * self.attention_scaling
217
+
218
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
219
+
220
+
221
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
222
+ def rotate_half(x):
223
+ """Rotates half the hidden dims of the input."""
224
+ x1 = x[..., : x.shape[-1] // 2]
225
+ x2 = x[..., x.shape[-1] // 2 :]
226
+ return torch.cat((-x2, x1), dim=-1)
227
+
228
+
229
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
230
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
231
+ """Applies Rotary Position Embedding to the query and key tensors.
232
+
233
+ Args:
234
+ q (`torch.Tensor`): The query tensor.
235
+ k (`torch.Tensor`): The key tensor.
236
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
237
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
238
+ position_ids (`torch.Tensor`, *optional*):
239
+ Deprecated and unused.
240
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
241
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
242
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
243
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
244
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
245
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
246
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
247
+ Returns:
248
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
249
+ """
250
+ cos = cos.unsqueeze(unsqueeze_dim)
251
+ sin = sin.unsqueeze(unsqueeze_dim)
252
+ q_embed = (q * cos) + (rotate_half(q) * sin)
253
+ k_embed = (k * cos) + (rotate_half(k) * sin)
254
+ return q_embed, k_embed
255
+
256
+
257
+ # Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmoe
258
+ class OlmoeMLP(nn.Module):
259
+ def __init__(self, config):
260
+ super().__init__()
261
+ self.config = config
262
+ self.hidden_size = config.hidden_size
263
+ self.intermediate_size = config.intermediate_size
264
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
265
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
266
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
267
+ self.act_fn = ACT2FN[config.hidden_act]
268
+
269
+ def forward(self, x):
270
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
271
+ return down_proj
272
+
273
+
274
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
275
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
276
+ """
277
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
278
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
279
+ """
280
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
281
+ if n_rep == 1:
282
+ return hidden_states
283
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
284
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
285
+
286
+
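# --- Editor's illustrative sketch (not part of the library file above/below) ------
# repeat_kv is documented as the equivalent of torch.repeat_interleave along the
# key/value head dimension; a quick self-check under that assumption:
#
#     x = torch.randn(2, 4, 7, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
#     assert torch.equal(repeat_kv(x, 3), x.repeat_interleave(3, dim=1))
# -----------------------------------------------------------------------------------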
287
+ class OlmoeAttention(nn.Module):
288
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
289
+
290
+ def __init__(self, config: OlmoeConfig, layer_idx: Optional[int] = None):
291
+ super().__init__()
292
+ self.config = config
293
+ self.layer_idx = layer_idx
294
+ if layer_idx is None:
295
+ logger.warning_once(
296
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
297
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
298
+ "when creating this class."
299
+ )
300
+
301
+ self.attention_dropout = config.attention_dropout
302
+ self.hidden_size = config.hidden_size
303
+ self.num_heads = config.num_attention_heads
304
+ self.head_dim = self.hidden_size // self.num_heads
305
+ self.num_key_value_heads = config.num_key_value_heads
306
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
307
+ self.max_position_embeddings = config.max_position_embeddings
308
+ self.rope_theta = config.rope_theta
309
+ self.is_causal = True
310
+
311
+ if (self.head_dim * self.num_heads) != self.hidden_size:
312
+ raise ValueError(
313
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
314
+ f" and `num_heads`: {self.num_heads})."
315
+ )
316
+
317
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
318
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
319
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
320
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
321
+ self.q_norm = OlmoeRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
322
+ self.k_norm = OlmoeRMSNorm(
323
+ (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps
324
+ )
325
+
326
+ def forward(
327
+ self,
328
+ hidden_states: torch.Tensor,
329
+ attention_mask: Optional[torch.Tensor] = None,
330
+ position_ids: Optional[torch.LongTensor] = None,
331
+ past_key_value: Optional[Cache] = None,
332
+ output_attentions: bool = False,
333
+ use_cache: bool = False,
334
+ cache_position: Optional[torch.LongTensor] = None,
335
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
336
+ **kwargs,
337
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
338
+ bsz, q_len, _ = hidden_states.size()
339
+
340
+ query_states = self.q_norm(self.q_proj(hidden_states))
341
+ key_states = self.k_norm(self.k_proj(hidden_states))
342
+ value_states = self.v_proj(hidden_states)
343
+
344
+ if self.config.clip_qkv is not None:
345
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
346
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
347
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
348
+
349
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
350
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
351
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
352
+
353
+ cos, sin = position_embeddings
354
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
355
+
356
+ if past_key_value is not None:
357
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
358
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
359
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
360
+
361
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
362
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
363
+
364
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
365
+
366
+ if attention_mask is not None: # no matter the length, we just slice it
367
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
368
+ attn_weights = attn_weights + causal_mask
369
+
370
+ # upcast attention to fp32
371
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
372
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
373
+ attn_output = torch.matmul(attn_weights, value_states)
374
+
375
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
376
+ raise ValueError(
377
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
378
+ f" {attn_output.size()}"
379
+ )
380
+
381
+ attn_output = attn_output.transpose(1, 2).contiguous()
382
+
383
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
384
+
385
+ attn_output = self.o_proj(attn_output)
386
+
387
+ if not output_attentions:
388
+ attn_weights = None
389
+
390
+ return attn_output, attn_weights, past_key_value
391
+
392
+
393
+ class OlmoeFlashAttention2(OlmoeAttention):
394
+ """
395
+ OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stay
396
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
397
+ flash attention and deal with padding tokens in case the input contains any of them.
398
+ """
399
+
400
+ def __init__(self, *args, **kwargs):
401
+ super().__init__(*args, **kwargs)
402
+
403
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
404
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
405
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
406
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
407
+
408
+ def forward(
409
+ self,
410
+ hidden_states: torch.Tensor,
411
+ attention_mask: Optional[torch.LongTensor] = None,
412
+ position_ids: Optional[torch.LongTensor] = None,
413
+ past_key_value: Optional[Cache] = None,
414
+ output_attentions: bool = False,
415
+ use_cache: bool = False,
416
+ cache_position: Optional[torch.LongTensor] = None,
417
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
418
+ **kwargs,
419
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
420
+ output_attentions = False
421
+
422
+ bsz, q_len, _ = hidden_states.size()
423
+
424
+ query_states = self.q_norm(self.q_proj(hidden_states))
425
+ key_states = self.k_norm(self.k_proj(hidden_states))
426
+ value_states = self.v_proj(hidden_states)
427
+ if self.config.clip_qkv is not None:
428
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
429
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
430
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
431
+
432
+ # Flash attention requires the input to have the shape
433
+ # batch_size x seq_length x head_dim x hidden_dim
434
+ # therefore we just need to keep the original shape
435
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
436
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
437
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
438
+
439
+ cos, sin = position_embeddings
440
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
441
+
442
+ if past_key_value is not None:
443
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
444
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
445
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
446
+
447
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
448
+ # to be able to avoid many of these transpose/reshape/view.
449
+ query_states = query_states.transpose(1, 2)
450
+ key_states = key_states.transpose(1, 2)
451
+ value_states = value_states.transpose(1, 2)
452
+
453
+ dropout_rate = self.attention_dropout if self.training else 0.0
454
+
455
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
456
+ # therefore the input hidden states get silently cast to float32. Hence, we need
457
+ # cast them back in the correct dtype just to be sure everything works as expected.
458
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
459
+ # in fp32. (OlmoeRMSNorm handles it correctly)
460
+
461
+ input_dtype = query_states.dtype
462
+ if input_dtype == torch.float32:
463
+ if torch.is_autocast_enabled():
464
+ target_dtype = torch.get_autocast_gpu_dtype()
465
+ # Handle the case where the model is quantized
466
+ elif hasattr(self.config, "_pre_quantization_dtype"):
467
+ target_dtype = self.config._pre_quantization_dtype
468
+ else:
469
+ target_dtype = self.q_proj.weight.dtype
470
+
471
+ logger.warning_once(
472
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
473
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
474
+ f" {target_dtype}."
475
+ )
476
+
477
+ query_states = query_states.to(target_dtype)
478
+ key_states = key_states.to(target_dtype)
479
+ value_states = value_states.to(target_dtype)
480
+
481
+ attn_output = _flash_attention_forward(
482
+ query_states,
483
+ key_states,
484
+ value_states,
485
+ attention_mask,
486
+ q_len,
487
+ dropout=dropout_rate,
488
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
489
+ is_causal=self.is_causal,
490
+ )
491
+
492
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
493
+ attn_output = self.o_proj(attn_output)
494
+
495
+ if not output_attentions:
496
+ attn_weights = None
497
+
498
+ return attn_output, attn_weights, past_key_value
499
+
500
+
501
+ class OlmoeSdpaAttention(OlmoeAttention):
502
+ """
503
+ OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
504
+ `OlmoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
505
+ SDPA API.
506
+ """
507
+
508
+ # Adapted from OlmoeAttention.forward
509
+ def forward(
510
+ self,
511
+ hidden_states: torch.Tensor,
512
+ attention_mask: Optional[torch.Tensor] = None,
513
+ position_ids: Optional[torch.LongTensor] = None,
514
+ past_key_value: Optional[Cache] = None,
515
+ output_attentions: bool = False,
516
+ use_cache: bool = False,
517
+ cache_position: Optional[torch.LongTensor] = None,
518
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
519
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
520
+ if output_attentions:
521
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
522
+ logger.warning_once(
523
+ "OlmoeModel is using OlmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
524
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
525
+ )
526
+ return super().forward(
527
+ hidden_states=hidden_states,
528
+ attention_mask=attention_mask,
529
+ position_ids=position_ids,
530
+ past_key_value=past_key_value,
531
+ output_attentions=output_attentions,
532
+ use_cache=use_cache,
533
+ cache_position=cache_position,
534
+ position_embeddings=position_embeddings,
535
+ )
536
+
537
+ bsz, q_len, _ = hidden_states.size()
538
+
539
+ query_states = self.q_norm(self.q_proj(hidden_states))
540
+ key_states = self.k_norm(self.k_proj(hidden_states))
541
+ value_states = self.v_proj(hidden_states)
542
+
543
+ if self.config.clip_qkv is not None:
544
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
545
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
546
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
547
+
548
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
549
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
550
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
551
+
552
+ cos, sin = position_embeddings
553
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
554
+
555
+ if past_key_value is not None:
556
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
557
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
558
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
559
+
560
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
561
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
562
+
563
+ causal_mask = attention_mask
564
+ # if attention_mask is not None and cache_position is not None:
565
+ if attention_mask is not None:
566
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
567
+
568
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
569
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
570
+ if query_states.device.type == "cuda" and causal_mask is not None:
571
+ query_states = query_states.contiguous()
572
+ key_states = key_states.contiguous()
573
+ value_states = value_states.contiguous()
574
+
575
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
576
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
577
+ is_causal = True if causal_mask is None and q_len > 1 else False
578
+
579
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
580
+ query_states,
581
+ key_states,
582
+ value_states,
583
+ attn_mask=causal_mask,
584
+ dropout_p=self.attention_dropout if self.training else 0.0,
585
+ is_causal=is_causal,
586
+ )
587
+
588
+ attn_output = attn_output.transpose(1, 2).contiguous()
589
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
590
+
591
+ attn_output = self.o_proj(attn_output)
592
+
593
+ return attn_output, None, past_key_value
594
+
595
+
596
+ OLMOE_ATTENTION_CLASSES = {
597
+ "eager": OlmoeAttention,
598
+ "flash_attention_2": OlmoeFlashAttention2,
599
+ "sdpa": OlmoeSdpaAttention,
600
+ }
601
+
602
+
603
+ class OlmoeSparseMoeBlock(nn.Module):
604
+ def __init__(self, config):
605
+ super().__init__()
606
+ self.num_experts = config.num_experts
607
+ self.top_k = config.num_experts_per_tok
608
+ self.norm_topk_prob = config.norm_topk_prob
609
+ self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
610
+ self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])
611
+
612
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
613
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
614
+ hidden_states = hidden_states.view(-1, hidden_dim)
615
+ # router_logits: (batch * sequence_length, n_experts)
616
+ router_logits = self.gate(hidden_states)
617
+
618
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
619
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
620
+ if self.norm_topk_prob:
621
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
622
+ # we cast back to the input dtype
623
+ routing_weights = routing_weights.to(hidden_states.dtype)
624
+
625
+ final_hidden_states = torch.zeros(
626
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
627
+ )
628
+
629
+ # One hot encode the selected experts to create an expert mask
630
+ # this will be used to easily index which expert is going to be selected
631
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
632
+
633
+ # Loop over all available experts in the model and perform the computation on each expert
634
+ for expert_idx in range(self.num_experts):
635
+ expert_layer = self.experts[expert_idx]
636
+ idx, top_x = torch.where(expert_mask[expert_idx])
637
+
638
+ # Index the correct hidden states and compute the expert hidden state for
639
+ # the current expert. We need to make sure to multiply the output hidden
640
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
641
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
642
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
643
+
644
+ # However `index_add_` only support torch tensors for indexing so we'll use
645
+ # the `top_x` tensor here.
646
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
647
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
648
+ return final_hidden_states, router_logits
649
+
650
+
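The scatter/gather bookkeeping above can obscure the routing arithmetic itself, so here is a toy sketch (editor illustration, not part of the file) of the top-k softmax routing for a single token with 4 experts and `top_k=2`:

```python
import torch
import torch.nn.functional as F

router_logits = torch.tensor([[1.0, 0.5, -0.2, 0.1]])        # (tokens, n_experts) = (1, 4)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # the norm_topk_prob branch
print(selected_experts)  # tensor([[0, 1]]) -- the two highest-probability experts
print(routing_weights)   # renormalized weights that scale each selected expert's output
```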
651
+ class OlmoeDecoderLayer(nn.Module):
652
+ def __init__(self, config: OlmoeConfig, layer_idx: int):
653
+ super().__init__()
654
+ self.hidden_size = config.hidden_size
655
+
656
+ self.self_attn = OLMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
657
+
658
+ self.mlp = OlmoeSparseMoeBlock(config)
659
+ self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
660
+ self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
661
+
662
+ def forward(
663
+ self,
664
+ hidden_states: torch.Tensor,
665
+ attention_mask: Optional[torch.Tensor] = None,
666
+ position_ids: Optional[torch.LongTensor] = None,
667
+ past_key_value: Optional[Cache] = None,
668
+ output_attentions: Optional[bool] = False,
669
+ output_router_logits: Optional[bool] = False,
670
+ use_cache: Optional[bool] = False,
671
+ cache_position: Optional[torch.LongTensor] = None,
672
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
673
+ **kwargs,
674
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
675
+ """
676
+ Args:
677
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
678
+ attention_mask (`torch.FloatTensor`, *optional*):
679
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
680
+ query_sequence_length, key_sequence_length)` if default attention is used.
681
+ output_attentions (`bool`, *optional*):
682
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
683
+ returned tensors for more detail.
684
+ output_router_logits (`bool`, *optional*):
685
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
686
+ and should not be returned during inference.
687
+ use_cache (`bool`, *optional*):
688
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
689
+ (see `past_key_values`).
690
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
691
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
692
+ Indices depicting the position of the input sequence tokens in the sequence
693
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
694
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
695
+ with `head_dim` being the embedding dimension of each attention head.
696
+ kwargs (`dict`, *optional*):
697
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
698
+ into the model
699
+ """
700
+ residual = hidden_states
701
+
702
+ hidden_states = self.input_layernorm(hidden_states)
703
+
704
+ # Self Attention
705
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
706
+ hidden_states=hidden_states,
707
+ attention_mask=attention_mask,
708
+ position_ids=position_ids,
709
+ past_key_value=past_key_value,
710
+ output_attentions=output_attentions,
711
+ use_cache=use_cache,
712
+ cache_position=cache_position,
713
+ position_embeddings=position_embeddings,
714
+ **kwargs,
715
+ )
716
+ hidden_states = residual + hidden_states
717
+
718
+ # Fully Connected
719
+ residual = hidden_states
720
+ hidden_states = self.post_attention_layernorm(hidden_states)
721
+ hidden_states, router_logits = self.mlp(hidden_states)
722
+ hidden_states = residual + hidden_states
723
+
724
+ outputs = (hidden_states,)
725
+
726
+ if output_attentions:
727
+ outputs += (self_attn_weights,)
728
+
729
+ if use_cache:
730
+ outputs += (present_key_value,)
731
+
732
+ if output_router_logits:
733
+ outputs += (router_logits,)
734
+
735
+ return outputs
736
+
737
+
738
+ OLMOE_START_DOCSTRING = r"""
739
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
740
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
741
+ etc.)
742
+
743
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
744
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
745
+ and behavior.
746
+
747
+ Parameters:
748
+ config ([`OlmoeConfig`]):
749
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
750
+ load the weights associated with the model, only the configuration. Check out the
751
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
752
+ """
753
+
754
+
755
+ @add_start_docstrings(
756
+ "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
757
+ OLMOE_START_DOCSTRING,
758
+ )
759
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmoe
760
+ class OlmoePreTrainedModel(PreTrainedModel):
761
+ config_class = OlmoeConfig
762
+ base_model_prefix = "model"
763
+ supports_gradient_checkpointing = True
764
+ _no_split_modules = ["OlmoeDecoderLayer"]
765
+ _skip_keys_device_placement = ["past_key_values"]
766
+ _supports_flash_attn_2 = True
767
+ _supports_sdpa = True
768
+ _supports_flex_attn = True
769
+ _supports_cache_class = True
770
+ _supports_quantized_cache = True
771
+ _supports_static_cache = True
772
+
773
+ def _init_weights(self, module):
774
+ std = self.config.initializer_range
775
+ if isinstance(module, nn.Linear):
776
+ module.weight.data.normal_(mean=0.0, std=std)
777
+ if module.bias is not None:
778
+ module.bias.data.zero_()
779
+ elif isinstance(module, nn.Embedding):
780
+ module.weight.data.normal_(mean=0.0, std=std)
781
+ if module.padding_idx is not None:
782
+ module.weight.data[module.padding_idx].zero_()
783
+
784
+
785
+ OLMOE_INPUTS_DOCSTRING = r"""
786
+ Args:
787
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
788
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
789
+ it.
790
+
791
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
792
+ [`PreTrainedTokenizer.__call__`] for details.
793
+
794
+ [What are input IDs?](../glossary#input-ids)
795
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
796
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
797
+
798
+ - 1 for tokens that are **not masked**,
799
+ - 0 for tokens that are **masked**.
800
+
801
+ [What are attention masks?](../glossary#attention-mask)
802
+
803
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
804
+ [`PreTrainedTokenizer.__call__`] for details.
805
+
806
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
807
+ `past_key_values`).
808
+
809
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
810
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
811
+ information on the default strategy.
812
+
813
+ - 1 indicates the head is **not masked**,
814
+ - 0 indicates the head is **masked**.
815
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
816
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
817
+ config.n_positions - 1]`.
818
+
819
+ [What are position IDs?](../glossary#position-ids)
820
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
821
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
822
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
823
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
824
+
825
+ Two formats are allowed:
826
+ - a [`~cache_utils.Cache`] instance;
827
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
828
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
829
+ cache format.
830
+
831
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
832
+ legacy cache format will be returned.
833
+
834
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
835
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
836
+ of shape `(batch_size, sequence_length)`.
837
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
838
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
839
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
840
+ model's internal embedding lookup matrix.
841
+ use_cache (`bool`, *optional*):
842
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
843
+ `past_key_values`).
844
+ output_attentions (`bool`, *optional*):
845
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
846
+ tensors for more detail.
847
+ output_hidden_states (`bool`, *optional*):
848
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
849
+ more detail.
850
+ output_router_logits (`bool`, *optional*):
851
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
852
+ should not be returned during inference.
853
+ return_dict (`bool`, *optional*):
854
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
855
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
856
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
857
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
858
+ the complete sequence length.
859
+ """
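As a usage note for the two cache formats described above, the sketch below passes an explicit `Cache` instance instead of the legacy tuple format. It is an editor illustration, not part of the file, and assumes the `allenai/OLMoE-1B-7B-0924` checkpoint referenced later in this module is available:

```python
from transformers import AutoTokenizer, OlmoeForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")

inputs = tokenizer("Mixture-of-experts models route tokens to", return_tensors="pt")
outputs = model(**inputs, past_key_values=DynamicCache(), use_cache=True)
print(type(outputs.past_key_values))  # the same DynamicCache, now filled with per-layer key/value states
```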
860
+
861
+
862
+ @add_start_docstrings(
863
+ "The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
864
+ OLMOE_START_DOCSTRING,
865
+ )
866
+ # TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
867
+ class OlmoeModel(OlmoePreTrainedModel):
868
+ """
869
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]
870
+
871
+ Args:
872
+ config: OlmoeConfig
873
+ """
874
+
875
+ def __init__(self, config: OlmoeConfig):
876
+ super().__init__(config)
877
+ self.padding_idx = config.pad_token_id
878
+ self.vocab_size = config.vocab_size
879
+
880
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
881
+ self.layers = nn.ModuleList(
882
+ [OlmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
883
+ )
884
+ self.norm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
885
+ self.rotary_emb = OlmoeRotaryEmbedding(config=config)
886
+ self.gradient_checkpointing = False
887
+
888
+ # Initialize weights and apply final processing
889
+ self.post_init()
890
+
891
+ def get_input_embeddings(self):
892
+ return self.embed_tokens
893
+
894
+ def set_input_embeddings(self, value):
895
+ self.embed_tokens = value
896
+
897
+ @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
898
+ # Ignore copy
899
+ def forward(
900
+ self,
901
+ input_ids: torch.LongTensor = None,
902
+ attention_mask: Optional[torch.Tensor] = None,
903
+ position_ids: Optional[torch.LongTensor] = None,
904
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
905
+ inputs_embeds: Optional[torch.FloatTensor] = None,
906
+ use_cache: Optional[bool] = None,
907
+ output_attentions: Optional[bool] = None,
908
+ output_hidden_states: Optional[bool] = None,
909
+ output_router_logits: Optional[bool] = None,
910
+ return_dict: Optional[bool] = None,
911
+ cache_position: Optional[torch.LongTensor] = None,
912
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
913
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
914
+ output_router_logits = (
915
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
916
+ )
917
+ output_hidden_states = (
918
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
919
+ )
920
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
921
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
922
+
923
+ if (input_ids is None) ^ (inputs_embeds is not None):
924
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
925
+
926
+ if self.gradient_checkpointing and self.training and use_cache:
927
+ logger.warning_once(
928
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
929
+ )
930
+ use_cache = False
931
+
932
+ if inputs_embeds is None:
933
+ inputs_embeds = self.embed_tokens(input_ids)
934
+
935
+ # kept for BC (non `Cache` `past_key_values` inputs)
936
+ return_legacy_cache = False
937
+ if use_cache and not isinstance(past_key_values, Cache):
938
+ return_legacy_cache = True
939
+ if past_key_values is None:
940
+ past_key_values = DynamicCache()
941
+ else:
942
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
943
+ logger.warning_once(
944
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
945
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
946
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
947
+ )
948
+
949
+ if cache_position is None:
950
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
951
+ cache_position = torch.arange(
952
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
953
+ )
954
+ if position_ids is None:
955
+ position_ids = cache_position.unsqueeze(0)
956
+
957
+ causal_mask = self._update_causal_mask(
958
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
959
+ )
960
+
961
+ # embed positions
962
+ hidden_states = inputs_embeds
963
+
964
+ # create position embeddings to be shared across the decoder layers
965
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
966
+
967
+ # decoder layers
968
+ all_hidden_states = () if output_hidden_states else None
969
+ all_self_attns = () if output_attentions else None
970
+ all_router_logits = () if output_router_logits else None
971
+ next_decoder_cache = None
972
+
973
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
974
+ if output_hidden_states:
975
+ all_hidden_states += (hidden_states,)
976
+
977
+ if self.gradient_checkpointing and self.training:
978
+ layer_outputs = self._gradient_checkpointing_func(
979
+ decoder_layer.__call__,
980
+ hidden_states,
981
+ causal_mask,
982
+ position_ids,
983
+ past_key_values,
984
+ output_attentions,
985
+ output_router_logits,
986
+ use_cache,
987
+ cache_position,
988
+ position_embeddings,
989
+ )
990
+ else:
991
+ layer_outputs = decoder_layer(
992
+ hidden_states,
993
+ attention_mask=causal_mask,
994
+ position_ids=position_ids,
995
+ past_key_value=past_key_values,
996
+ output_attentions=output_attentions,
997
+ output_router_logits=output_router_logits,
998
+ use_cache=use_cache,
999
+ cache_position=cache_position,
1000
+ position_embeddings=position_embeddings,
1001
+ )
1002
+
1003
+ hidden_states = layer_outputs[0]
1004
+
1005
+ if use_cache:
1006
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1007
+
1008
+ if output_attentions:
1009
+ all_self_attns += (layer_outputs[1],)
1010
+
1011
+ if output_router_logits and layer_outputs[-1] is not None:
1012
+ all_router_logits += (layer_outputs[-1],)
1013
+
1014
+ hidden_states = self.norm(hidden_states)
1015
+
1016
+ # add hidden states from the last decoder layer
1017
+ if output_hidden_states:
1018
+ all_hidden_states += (hidden_states,)
1019
+
1020
+ next_cache = next_decoder_cache if use_cache else None
1021
+ if return_legacy_cache:
1022
+ next_cache = next_cache.to_legacy_cache()
1023
+
1024
+ if not return_dict:
1025
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None)
1026
+ return MoeModelOutputWithPast(
1027
+ last_hidden_state=hidden_states,
1028
+ past_key_values=next_cache,
1029
+ hidden_states=all_hidden_states,
1030
+ attentions=all_self_attns,
1031
+ router_logits=all_router_logits,
1032
+ )
1033
+
1034
+ def _update_causal_mask(
1035
+ self,
1036
+ attention_mask: torch.Tensor,
1037
+ input_tensor: torch.Tensor,
1038
+ cache_position: torch.Tensor,
1039
+ past_key_values: Cache,
1040
+ output_attentions: bool,
1041
+ ):
1042
+ if self.config._attn_implementation == "flash_attention_2":
1043
+ if attention_mask is not None and 0.0 in attention_mask:
1044
+ return attention_mask
1045
+ return None
1046
+
1047
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1048
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1049
+ # to infer the attention mask.
1050
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1051
+ using_static_cache = isinstance(past_key_values, StaticCache)
1052
+
1053
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1054
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1055
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1056
+ attention_mask,
1057
+ inputs_embeds=input_tensor,
1058
+ past_key_values_length=past_seen_tokens,
1059
+ is_training=self.training,
1060
+ ):
1061
+ return None
1062
+
1063
+ dtype, device = input_tensor.dtype, input_tensor.device
1064
+ sequence_length = input_tensor.shape[1]
1065
+ if using_static_cache:
1066
+ target_length = past_key_values.get_max_cache_shape()
1067
+ else:
1068
+ target_length = (
1069
+ attention_mask.shape[-1]
1070
+ if isinstance(attention_mask, torch.Tensor)
1071
+ else past_seen_tokens + sequence_length + 1
1072
+ )
1073
+
1074
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1075
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1076
+ attention_mask,
1077
+ sequence_length=sequence_length,
1078
+ target_length=target_length,
1079
+ dtype=dtype,
1080
+ device=device,
1081
+ cache_position=cache_position,
1082
+ batch_size=input_tensor.shape[0],
1083
+ )
1084
+
1085
+ if (
1086
+ self.config._attn_implementation == "sdpa"
1087
+ and attention_mask is not None
1088
+ and attention_mask.device.type == "cuda"
1089
+ and not output_attentions
1090
+ ):
1091
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1092
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1093
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1094
+ min_dtype = torch.finfo(dtype).min
1095
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1096
+
1097
+ return causal_mask
1098
+
1099
+ @staticmethod
1100
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1101
+ attention_mask: torch.Tensor,
1102
+ sequence_length: int,
1103
+ target_length: int,
1104
+ dtype: torch.dtype,
1105
+ device: torch.device,
1106
+ cache_position: torch.Tensor,
1107
+ batch_size: int,
1108
+ **kwargs,
1109
+ ):
1110
+ """
1111
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1112
+ `(batch_size, key_value_length)`, or, if the input `attention_mask` is already 4D, returns it unchanged.
1113
+
1114
+ Args:
1115
+ attention_mask (`torch.Tensor`):
1116
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1117
+ `(batch_size, 1, query_length, key_value_length)`.
1118
+ sequence_length (`int`):
1119
+ The sequence length being processed.
1120
+ target_length (`int`):
1121
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1122
+ to account for the 0 padding (the part of the cache that is not filled yet).
1123
+ dtype (`torch.dtype`):
1124
+ The dtype to use for the 4D attention mask.
1125
+ device (`torch.device`):
1126
+ The device to place the 4D attention mask on.
1127
+ cache_position (`torch.Tensor`):
1128
+ Indices depicting the position of the input sequence tokens in the sequence.
1129
+ batch_size (`int`):
1130
+ Batch size.
1131
+ """
1132
+ if attention_mask is not None and attention_mask.dim() == 4:
1133
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1134
+ causal_mask = attention_mask
1135
+ else:
1136
+ min_dtype = torch.finfo(dtype).min
1137
+ causal_mask = torch.full(
1138
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1139
+ )
1140
+ if sequence_length != 1:
1141
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1142
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1143
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1144
+ if attention_mask is not None:
1145
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1146
+ mask_length = attention_mask.shape[-1]
1147
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1148
+ padding_mask = padding_mask == 0
1149
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1150
+ padding_mask, min_dtype
1151
+ )
1152
+
1153
+ return causal_mask
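To make the mask construction above concrete, here is a toy worked example (editor sketch, not part of the file) with `sequence_length=3`, `target_length=5`, and `cache_position=[2, 3, 4]`, i.e. two tokens already sitting in the cache:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 5)                    # absolute positions of the 3 new tokens
causal_mask = torch.full((3, 5), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(5) > cache_position.reshape(-1, 1)
print((causal_mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])  -> row i attends to every position up to cache_position[i]
```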
1154
+
1155
+
1156
+ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
1157
+ _tied_weights_keys = ["lm_head.weight"]
1158
+
1159
+ def __init__(self, config):
1160
+ super().__init__(config)
1161
+ self.model = OlmoeModel(config)
1162
+ self.vocab_size = config.vocab_size
1163
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1164
+
1165
+ self.router_aux_loss_coef = config.router_aux_loss_coef
1166
+ self.num_experts = config.num_experts
1167
+ self.num_experts_per_tok = config.num_experts_per_tok
1168
+ # Initialize weights and apply final processing
1169
+ self.post_init()
1170
+
1171
+ def get_input_embeddings(self):
1172
+ return self.model.embed_tokens
1173
+
1174
+ def set_input_embeddings(self, value):
1175
+ self.model.embed_tokens = value
1176
+
1177
+ def get_output_embeddings(self):
1178
+ return self.lm_head
1179
+
1180
+ def set_output_embeddings(self, new_embeddings):
1181
+ self.lm_head = new_embeddings
1182
+
1183
+ def set_decoder(self, decoder):
1184
+ self.model = decoder
1185
+
1186
+ def get_decoder(self):
1187
+ return self.model
1188
+
1189
+ @add_start_docstrings_to_model_forward(OLMOE_INPUTS_DOCSTRING)
1190
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1191
+ def forward(
1192
+ self,
1193
+ input_ids: torch.LongTensor = None,
1194
+ attention_mask: Optional[torch.Tensor] = None,
1195
+ position_ids: Optional[torch.LongTensor] = None,
1196
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1197
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1198
+ labels: Optional[torch.LongTensor] = None,
1199
+ use_cache: Optional[bool] = None,
1200
+ output_attentions: Optional[bool] = None,
1201
+ output_hidden_states: Optional[bool] = None,
1202
+ output_router_logits: Optional[bool] = None,
1203
+ return_dict: Optional[bool] = None,
1204
+ cache_position: Optional[torch.LongTensor] = None,
1205
+ num_logits_to_keep: int = 0,
1206
+ **loss_kwargs,
1207
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1208
+ r"""
1209
+ Args:
1210
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1211
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1212
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1213
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1214
+
1215
+ num_logits_to_keep (`int`, *optional*):
1216
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1217
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1218
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1219
+
1220
+ Returns:
1221
+
1222
+ Example:
1223
+
1224
+ ```python
1225
+ >>> from transformers import AutoTokenizer, OlmoeForCausalLM
1226
+
1227
+ >>> model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
1228
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")
1229
+
1230
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1231
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1232
+
1233
+ >>> # Generate
1234
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1235
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1236
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
1237
+ ```
1238
+ """
1239
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1240
+ output_router_logits = (
1241
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1242
+ )
1243
+ output_hidden_states = (
1244
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1245
+ )
1246
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1247
+
1248
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1249
+ outputs = self.model(
1250
+ input_ids=input_ids,
1251
+ attention_mask=attention_mask,
1252
+ position_ids=position_ids,
1253
+ past_key_values=past_key_values,
1254
+ inputs_embeds=inputs_embeds,
1255
+ use_cache=use_cache,
1256
+ output_attentions=output_attentions,
1257
+ output_hidden_states=output_hidden_states,
1258
+ output_router_logits=output_router_logits,
1259
+ return_dict=return_dict,
1260
+ cache_position=cache_position,
1261
+ )
1262
+
1263
+ hidden_states = outputs[0]
1264
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1265
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1266
+
1267
+ loss = None
1268
+ if labels is not None:
1269
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1270
+
1271
+ aux_loss = None
1272
+ if output_router_logits:
1273
+ aux_loss = load_balancing_loss_func(
1274
+ outputs.router_logits if return_dict else outputs[-1],
1275
+ self.num_experts,
1276
+ self.num_experts_per_tok,
1277
+ attention_mask,
1278
+ )
1279
+ if labels is not None:
1280
+ loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
1281
+
1282
+ if not return_dict:
1283
+ output = (logits,) + outputs[1:]
1284
+ if output_router_logits:
1285
+ output = (aux_loss,) + output
1286
+ return (loss,) + output if loss is not None else output
1287
+
1288
+ return MoeCausalLMOutputWithPast(
1289
+ loss=loss,
1290
+ aux_loss=aux_loss,
1291
+ logits=logits,
1292
+ past_key_values=outputs.past_key_values,
1293
+ hidden_states=outputs.hidden_states,
1294
+ attentions=outputs.attentions,
1295
+ router_logits=outputs.router_logits,
1296
+ )
1297
+
1298
+
1299
+ __all__ = ["OlmoeForCausalLM", "OlmoeModel", "OlmoePreTrainedModel"]
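Finally, a hedged sketch (editor illustration, not part of the file) of how the auxiliary load-balancing loss computed in `OlmoeForCausalLM.forward` above folds into the training loss; `load_balancing_loss_func` is the helper this module already uses, and the 0.01 coefficient merely stands in for `config.router_aux_loss_coef`:

```python
import torch

num_experts, top_k = 4, 2
router_logits = tuple(torch.randn(6, num_experts) for _ in range(2))   # 2 layers, 6 tokens each
aux_loss = load_balancing_loss_func(router_logits, num_experts, top_k, attention_mask=None)

ce_loss = torch.tensor(2.3)             # stand-in for the cross-entropy loss on the labels
total_loss = ce_loss + 0.01 * aux_loss  # same combination as in forward(): loss += coef * aux_loss
```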
.venv/lib/python3.11/site-packages/transformers/models/pegasus/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_pegasus import *
22
+ from .modeling_flax_pegasus import *
23
+ from .modeling_pegasus import *
24
+ from .modeling_tf_pegasus import *
25
+ from .tokenization_pegasus import *
26
+ from .tokenization_pegasus_fast import *
27
+ else:
28
+ import sys
29
+
30
+ _file = globals()["__file__"]
31
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
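A brief usage note on the `_LazyModule` pattern above (editor sketch, not part of the file): submodule imports are deferred until an attribute of the package is first accessed.

```python
from transformers.models import pegasus

config_cls = pegasus.PegasusConfig   # first attribute access loads configuration_pegasus lazily
print(config_cls().model_type)       # "pegasus"
```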
.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (942 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/pegasus/__pycache__/configuration_pegasus.cpython-311.pyc ADDED
Binary file (7.29 kB). View file